From e02ce498be2e11a165803d4590588ba98f129797 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 4 Sep 2024 15:18:13 -0500 Subject: [PATCH] [Feature] OpenAI-Compatible Tools API + Streaming for Hermes & Mistral models (#5649) Co-authored-by: constellate Co-authored-by: Kyle Mistele --- .buildkite/test-pipeline.yaml | 10 + .../serving/openai_compatible_server.md | 58 ++- ...penai_chat_completion_client_with_tools.py | 162 +++++++++ examples/tool_chat_template_hermes.jinja | 129 +++++++ examples/tool_chat_template_mistral.jinja | 86 +++++ .../tool_chat_template_mistral_parallel.jinja | 94 +++++ requirements-common.txt | 1 + tests/tool_use/__init__.py | 0 tests/tool_use/conftest.py | 32 ++ tests/tool_use/test_chat_completions.py | 143 ++++++++ tests/tool_use/test_parallel_tool_calls.py | 193 ++++++++++ tests/tool_use/test_tool_calls.py | 192 ++++++++++ tests/tool_use/utils.py | 215 +++++++++++ vllm/entrypoints/chat_utils.py | 101 ++++- vllm/entrypoints/openai/api_server.py | 8 +- vllm/entrypoints/openai/cli_args.py | 18 + vllm/entrypoints/openai/protocol.py | 125 ++++++- vllm/entrypoints/openai/serving_chat.py | 275 ++++++++++++-- .../openai/serving_tokenization.py | 6 +- .../openai/tool_parsers/__init__.py | 5 + .../tool_parsers/abstract_tool_parser.py | 58 +++ .../openai/tool_parsers/hermes_tool_parser.py | 344 ++++++++++++++++++ .../tool_parsers/mistral_tool_parser.py | 293 +++++++++++++++ vllm/entrypoints/openai/tool_parsers/utils.py | 87 +++++ .../guided_decoding/__init__.py | 5 +- .../guided_decoding/outlines_decoding.py | 31 +- 26 files changed, 2588 insertions(+), 83 deletions(-) create mode 100644 examples/openai_chat_completion_client_with_tools.py create mode 100644 examples/tool_chat_template_hermes.jinja create mode 100644 examples/tool_chat_template_mistral.jinja create mode 100644 examples/tool_chat_template_mistral_parallel.jinja create mode 100644 tests/tool_use/__init__.py create mode 100644 tests/tool_use/conftest.py create mode 100644 tests/tool_use/test_chat_completions.py create mode 100644 tests/tool_use/test_parallel_tool_calls.py create mode 100644 tests/tool_use/test_tool_calls.py create mode 100644 tests/tool_use/utils.py create mode 100644 vllm/entrypoints/openai/tool_parsers/__init__.py create mode 100644 vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 65e1862ce8181..d50d8f32a816d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -92,6 +92,7 @@ steps: - pytest -v -s entrypoints/openai - pytest -v -s entrypoints/test_chat_utils.py + - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -271,6 +272,15 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI-Compatible Tool Use # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use + ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index b2acde390083c..eb4ea0fb5655e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -110,6 +110,14 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) :func: create_parser_for_docs :prog: vllm serve ``` +## Tool Calling in the Chat Completion API +### Named Function Calling +vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. + +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. ### Config file @@ -140,10 +148,52 @@ The order of priorities is `command line > config file values > defaults`. ## Tool calling in the chat completion API vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. -To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter. - -It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.** +It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. -Please refer to the OpenAI API reference documentation for more information. + +### Automatic Function Calling +To enable this feature, you should set the following flags: +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +deems appropriate. +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers +will continue to be added in the future. +* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their +`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat +template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) + +If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! + +#### Hermes Models +All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. +* `NousResearch/Hermes-2-Pro-*` +* `NousResearch/Hermes-2-Theta-*` +* `NousResearch/Hermes-3-*` + + +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +step in their creation_. + +Flags: `--tool-call-parser hermes` + +#### Mistral Models +Supported models: +* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) +* Additional mistral function-calling models are compatible as well. + +Known issues: +1. Mistral 7B struggles to generate parallel tool calls correctly. +2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +much shorter than what vLLM generates. Since an exception is thrown when this condition +is not met, the following additional chat templates are provided: + +* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) +* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +when tools are provided, that results in much better reliability when working with parallel tool calling. + + +Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py new file mode 100644 index 0000000000000..2bbe42b6bd2ef --- /dev/null +++ b/examples/openai_chat_completion_client_with_tools.py @@ -0,0 +1,162 @@ +""" +Set up this example by starting a vLLM OpenAI-compatible server with tool call +options enabled. For example: + +IMPORTANT: for mistral, you must use one of the provided mistral tool call +templates, or your own - the model default doesn't work for tool calls with vLLM +See the vLLM docs on OpenAI server & tool calling for more details. + +vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \ + --chat-template examples/tool_chat_template_mistral.jinja \ + --enable-auto-tool-choice --tool-call-parser mistral + +OR +vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ + --chat-template examples/tool_chat_template_hermes.jinja \ + --enable-auto-tool-choice --tool-call-parser hermes +""" +import json + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] + +messages = [{ + "role": "user", + "content": "Hi! How are you doing today?" +}, { + "role": "assistant", + "content": "I'm doing well! How can I help you?" +}, { + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" +}] + +chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools) + +print("Chat completion results:") +print(chat_completion) +print("\n\n") + +tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) + +chunks = [] +for chunk in tool_calls_stream: + chunks.append(chunk) + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + else: + print(chunk.choices[0].delta) + +arguments = [] +tool_call_idx = -1 +for chunk in chunks: + + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + + if tool_call.index != tool_call_idx: + if tool_call_idx >= 0: + print( + f"streamed tool call arguments: {arguments[tool_call_idx]}" + ) + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append("") + if tool_call.id: + print(f"streamed tool call id: {tool_call.id} ") + + if tool_call.function: + if tool_call.function.name: + print(f"streamed tool call name: {tool_call.function.name}") + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments + +if len(arguments): + print(f"streamed tool call arguments: {arguments[-1]}") + +print("\n\n") + +messages.append({ + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls +}) + + +# Now, simulate a tool call +def get_current_weather(city: str, state: str, unit: 'str'): + return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " + "partly cloudly, with highs in the 90's.") + + +available_tools = {"get_current_weather": get_current_weather} + +completion_tool_calls = chat_completion.choices[0].message.tool_calls +for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + print(result) + messages.append({ + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name + }) + +chat_completion_2 = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=False) +print("\n\n") +print(chat_completion_2) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja new file mode 100644 index 0000000000000..b18b463032d4f --- /dev/null +++ b/examples/tool_chat_template_hermes.jinja @@ -0,0 +1,129 @@ +{%- macro json_to_python_type(json_spec) %} + {%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {%- else %} + {{- "Any" }} + {%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- if tools is iterable and tools | length > 0 %} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} +{%- endif %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|>' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" and message.tool_calls is defined %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"}' }} + {{- ', ' }} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {%- if not loop.last %} + {{- '\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja new file mode 100644 index 0000000000000..49691f59c2f2c --- /dev/null +++ b/examples/tool_chat_template_mistral.jinja @@ -0,0 +1,86 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/examples/tool_chat_template_mistral_parallel.jinja b/examples/tool_chat_template_mistral_parallel.jinja new file mode 100644 index 0000000000000..a294cbfd026be --- /dev/null +++ b/examples/tool_chat_template_mistral_parallel.jinja @@ -0,0 +1,94 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- if tools is defined %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/requirements-common.txt b/requirements-common.txt index 4c5b681a0d5ab..447fd32311c09 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,6 +20,7 @@ lm-format-enforcer == 0.10.6 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec gguf == 0.9.1 diff --git a/tests/tool_use/__init__.py b/tests/tool_use/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py new file mode 100644 index 0000000000000..ab6a29eba1b3f --- /dev/null +++ b/tests/tool_use/conftest.py @@ -0,0 +1,32 @@ +import pytest +import pytest_asyncio +from huggingface_hub import snapshot_download + +from tests.utils import RemoteOpenAIServer + +from .utils import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="session", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="session") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_wait_seconds=480) as server: + yield server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py new file mode 100644 index 0000000000000..038ff81d2b674 --- /dev/null +++ b/tests/tool_use/test_chat_completions.py @@ -0,0 +1,143 @@ +from typing import List + +import openai +import pytest + +from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL + + +# test: make sure chat completions without tools provided work even when tools +# are enabled. This makes sure tool call chat templates work, AND that the tool +# parser stream processing doesn't change the output of the model. +@pytest.mark.asyncio +async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert len(output_text) > 0 + assert stop_reason != "tool_calls" + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert not role_sent + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +# test: conversation with tools enabled and provided that should not invoke +# tools, to make sure we can still get normal chat completion responses +# and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert stop_reason != 'tool_calls' + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert chunk.choices[0].finish_reason != 'tool_calls' + assert len(chunks) + assert "".join(chunks) == output_text diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py new file mode 100644 index 0000000000000..b03b5a2075a6c --- /dev/null +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -0,0 +1,193 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + WEATHER_TOOL) + + +# test: getting the model to generate parallel tool calls (streaming/not) +# when requested. NOTE that not all models may support this, so some exclusions +# may be added in the future. e.g. llama 3.1 models are not designed to support +# parallel tool calls. +@pytest.mark.asyncio +async def test_parallel_tool_calls(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure 2 tool calls are present + assert choice.message.role == "assistant" + assert non_streamed_tool_calls is not None + assert len(non_streamed_tool_calls) == 2 + + for tool_call in non_streamed_tool_calls: + # make sure the tool includes a function and ID + assert tool_call.type == "function" + assert tool_call.function is not None + assert isinstance(tool_call.id, str) + assert len(tool_call.id) > 16 + + # make sure the weather tool was called correctly + assert tool_call.function.name == WEATHER_TOOL["function"]["name"] + assert isinstance(tool_call.function.arguments, str) + + parsed_arguments = json.loads(tool_call.function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + + assert stop_reason == "tool_calls" + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + role_name: Optional[str] = None + finish_reason_count: int = 0 + + tool_call_names: List[str] = [] + tool_call_args: List[str] = [] + tool_call_idx: int = -1 + tool_call_id_count: int = 0 + + async for chunk in stream: + + # if there's a finish reason make sure it's tools + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + tool_call_args.append("") + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + tool_call_id_count += 1 + assert (isinstance(tool_call.id, str) + and (len(tool_call.id) > 16)) + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + tool_call_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + tool_call_args[ + tool_call.index] += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + + assert (len(non_streamed_tool_calls) == len(tool_call_names) == + len(tool_call_args)) + + for i in range(2): + assert non_streamed_tool_calls[i].function.name == tool_call_names[i] + streamed_args = json.loads(tool_call_args[i]) + non_streamed_args = json.loads( + non_streamed_tool_calls[i].function.arguments) + assert streamed_args == non_streamed_args + + +# test: providing parallel tool calls back to the model to get a response +# (streaming/not) +@pytest.mark.asyncio +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # Dallas temp in tool response + assert "78" in choice.message.content # Orlando temp in tool response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py new file mode 100644 index 0000000000000..c3abe9e1f5060 --- /dev/null +++ b/tests/tool_use/test_tool_calls.py @@ -0,0 +1,192 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, WEATHER_TOOL) + + +# test: request a chat completion that should return tool calls, so we know they +# are parsable +@pytest.mark.asyncio +async def test_tool_call_and_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure a tool call is present + assert choice.message.role == 'assistant' + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].type == 'function' + assert tool_calls[0].function is not None + assert isinstance(tool_calls[0].id, str) + assert len(tool_calls[0].id) > 16 + + # make sure the weather tool was called (classic example) with arguments + assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] + assert tool_calls[0].function.arguments is not None + assert isinstance(tool_calls[0].function.arguments, str) + + # make sure the arguments parse properly + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + assert parsed_arguments.get("city") == "Dallas" + assert parsed_arguments.get("state") == "TX" + + assert stop_reason == "tool_calls" + + function_name: Optional[str] = None + function_args_str: str = '' + tool_call_id: Optional[str] = None + role_name: Optional[str] = None + finish_reason_count: int = 0 + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=100, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + async for chunk in stream: + assert chunk.choices[0].index == 0 + + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert function_name is None + assert isinstance(tool_call.function.name, str) + function_name = tool_call.function.name + if tool_call.function.arguments: + assert isinstance(tool_call.function.arguments, str) + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + + # validate the name and arguments + assert function_name == WEATHER_TOOL["function"]["name"] + assert function_name == tool_calls[0].function.name + assert isinstance(function_args_str, str) + + # validate arguments + streamed_args = json.loads(function_args_str) + assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args.get("city"), str) + assert isinstance(streamed_args.get("state"), str) + assert streamed_args.get("city") == "Dallas" + assert streamed_args.get("state") == "TX" + + # make sure everything matches non-streaming except for ID + assert function_name == tool_calls[0].function.name + assert choice.message.role == role_name + assert choice.message.tool_calls[0].function.name == function_name + + # compare streamed with non-streamed args Dict-wise, not string-wise + # because character-to-character comparison might not work e.g. the tool + # call parser adding extra spaces or something like that. we care about the + # dicts matching not byte-wise match + assert parsed_arguments == streamed_args + + +# test: providing tools and results back to model to get a non-tool response +# (streaming/not) +@pytest.mark.asyncio +async def test_tool_call_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # the temperature from the response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py new file mode 100644 index 0000000000000..8ec9b05b2c521 --- /dev/null +++ b/tests/tool_use/utils.py @@ -0,0 +1,215 @@ +from typing import Dict, List + +from openai.types.chat import (ChatCompletionMessageParam, + ChatCompletionToolParam) +from typing_extensions import TypedDict + +from tests.utils import VLLM_PATH + + +class ServerConfig(TypedDict): + model: str + arguments: List[str] + + +# universal args for all models go here. also good if you need to test locally +# and change type or KV cache quantization or something. +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] + +CONFIGS: Dict[str, ServerConfig] = { + "hermes": { + "model": + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "arguments": [ + "--tool-call-parser", "hermes", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + ] + }, + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tool-call-parser", "mistral", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), + "--ignore-patterns=\"consolidated.safetensors\"" + ] + } +} + +WEATHER_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, " + "e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + } + } + } +} + +SEARCH_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": + "web_search", + "description": + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": + "string", + "description": + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" + } + }, + "required": ["search_term"] + } + } +} + +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you tell me a joke please?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}] + +MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." +}] + +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}] + +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }, { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening." +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": + "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies." +}] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f205a99920892..9a7493649c795 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,23 +1,28 @@ import asyncio import codecs +import json from abc import ABC, abstractmethod from collections import defaultdict -from functools import lru_cache +from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, - Mapping, Optional, Tuple, TypeVar, Union) + Mapping, Optional, Tuple, TypeVar, Union, cast) # yapf conflicts with isort for this block # yapf: disable -from openai.types.chat import ChatCompletionContentPartImageParam +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import ChatCompletionContentPartTextParam +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict, TypeAdapter +from pydantic import ConfigDict from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -54,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam, ] + ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPartParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -72,15 +78,33 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): same role. """ + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] # TODO: Make fields ReadOnly once mypy supports it -class ConversationMessage(TypedDict): - role: str - content: str +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Optional[str] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" ModalityStr = Literal["image", "audio"] @@ -319,9 +343,11 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], return "\n".join(missing_placeholders + [text_prompt]) -_TextParser = TypeAdapter(ChatCompletionContentPartTextParam) -_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam) -_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam) +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) def _parse_chat_message_content_parts( @@ -336,10 +362,10 @@ def _parse_chat_message_content_parts( for part in parts: part_type = part["type"] if part_type == "text": - text = _TextParser.validate_python(part)["text"] + text = _TextParser(part)["text"] texts.append(text) elif part_type == "image_url": - image_url = _ImageParser.validate_python(part)["image_url"] + image_url = _ImageParser(part)["image_url"] if image_url.get("detail", "auto") != "auto": logger.warning( @@ -348,7 +374,7 @@ def _parse_chat_message_content_parts( mm_parser.parse_image(image_url["url"]) elif part_type == "audio_url": - audio_url = _AudioParser.validate_python(part)["audio_url"] + audio_url = _AudioParser(part)["audio_url"] mm_parser.parse_audio(audio_url["url"]) else: @@ -363,6 +389,11 @@ def _parse_chat_message_content_parts( return [ConversationMessage(role=role, content=text_prompt)] +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, @@ -371,16 +402,34 @@ def _parse_chat_message_content( content = message.get("content") if content is None: - return [] - if isinstance(content, str): - return [ConversationMessage(role=role, content=content)] + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] - return _parse_chat_message_content_parts( + result = _parse_chat_message_content_parts( role, content, # type: ignore mm_tracker, ) + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + def parse_chat_messages( messages: List[ChatCompletionMessageParam], @@ -428,6 +477,20 @@ def apply_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one.") + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in conversation: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for i in range(len(message["tool_calls"])): + args: str = message["tool_calls"][i]["function"]["arguments"] + parsed_args: Dict = json.loads(args) + message["tool_calls"][i]["function"]["arguments"] = parsed_args + prompt = tokenizer.apply_chat_template( conversation=conversation, chat_template=chat_template, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7632e8aa5e32e..728a2e5232d9b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -233,7 +233,7 @@ def mount_metrics(app: FastAPI): metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile('^/metrics(?P.*)$') + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") app.routes.append(metrics_route) @@ -283,11 +283,14 @@ async def show_version(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + generator = await openai_serving_chat.create_chat_completion( request, raw_request) + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) + elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -422,7 +425,8 @@ async def init_app( request_logger=request_logger, chat_template=args.chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser) openai_serving_completion = OpenAIServingCompletion( async_engine_client, model_config, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 94742838b421c..7ccee0b6b55b7 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") + + parser.add_argument( + "--tool-call-parser", + type=str, + choices=["mistral", "hermes"], + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0954b81595ef5..ff9c3690672b6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,8 +5,9 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch +from openai.types.chat import ChatCompletionContentPartParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated +from typing_extensions import Annotated, Required, TypedDict from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors @@ -35,6 +36,26 @@ assert _LONG_INFO.max == _MOCK_LONG_INFO.max +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") @@ -145,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[Literal["none"], + tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False user: Optional[str] = None # doc: begin-chat-completion-sampling-params @@ -328,6 +352,9 @@ def check_logprobs(cls, data): @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, @@ -339,21 +366,61 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if guide_count > 1 and "tool_choice" in data and data[ - "tool_choice"] != "none": + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data @model_validator(mode="before") @classmethod - def check_tool_choice(cls, data): - if "tool_choice" in data and data["tool_choice"] != "none": - if not isinstance(data["tool_choice"], dict): - raise ValueError("Currently only named tools are supported.") + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and "tools" in data: + data["tool_choice"] = "auto" + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: raise ValueError( "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): + raise ValueError( + "`tool_choice` must either be a named tool or \"auto\". " + "`tool_choice=\"none\" is not supported.") + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"]["function"] + if not specified_function: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function["name"] + if not specified_function_name: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") return data @@ -413,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel): ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, - description=("If specified, the output will follow the JSON schema."), + description="If specified, the output will follow the JSON schema.", ) guided_regex: Optional[str] = Field( default=None, @@ -633,9 +700,41 @@ class ToolCall(OpenAIBaseModel): function: FunctionCall +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +# the initial delta that gets sent once a new tool call is started; +class InitialDeltaToolCall(DeltaToolCall): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + class ChatMessage(OpenAIBaseModel): role: str - content: str + content: Optional[str] = None tool_calls: List[ToolCall] = Field(default_factory=list) @@ -657,7 +756,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None @@ -674,7 +775,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None - tool_calls: List[ToolCall] = Field(default_factory=list) + tool_calls: List[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a3bc0bb7b3554..78f355228012f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,8 @@ import asyncio +import json import time -from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, + Optional) from typing import Sequence as GenericSequence from typing import Union @@ -18,15 +20,18 @@ ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - FunctionCall, ToolCall, UsageInfo) + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath, TextTokensPrompt) +from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, + MistralToolParser, + ToolParser) from vllm.inputs import TokensPrompt from vllm.logger import init_logger -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) @@ -38,19 +43,19 @@ class OpenAIServingChat(OpenAIServing): - def __init__( - self, - async_engine_client: AsyncEngineClient, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - ): + def __init__(self, + async_engine_client: AsyncEngineClient, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None): super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, @@ -60,10 +65,27 @@ def __init__( return_tokens_as_token_ids=return_tokens_as_token_ids) self.response_role = response_role - - # If this is None we use the tokenizer's default chat template + self.use_tool_use_model_template = False self.chat_template = load_chat_template(chat_template) + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + if tool_parser == "mistral": + self.tool_parser = MistralToolParser + elif tool_parser == "hermes": + self.tool_parser = Hermes2ProToolParser + else: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -76,11 +98,10 @@ async def create_chat_completion( for the API specification. This API mimics the OpenAI ChatCompletion API. - NOTE: Currently we do not support the following feature: - - function_call (Users should implement this by themselves) """ error_check_ret = await self._check_model(request) if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) return error_check_ret try: @@ -119,6 +140,20 @@ async def create_chat_completion( logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) + # validation for OpenAI tools + # tool_choice = "required" is not supported + if request.tool_choice == "required": + return self.create_error_response( + "tool_choice = \"required\" is not supported!") + + # "auto" tools requires --enable-auto-tool-choice + # and --tool-call-parser + if request.tool_choice == "auto" and not ( + self.enable_auto_tools and self.tool_parser is not None): + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set") + request_id = f"chat-{random_uuid()}" try: guided_decode_logits_processor = ( @@ -187,6 +222,7 @@ async def create_chat_completion( if request.stream: return self.chat_completion_stream_generator( request, result_generator, request_id, conversation, tokenizer) + try: return await self.chat_completion_full_generator( request, result_generator, request_id, conversation, tokenizer) @@ -219,6 +255,9 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + tool_parser: Optional[ToolParser] = self.tool_parser( + tokenizer) if self.tool_parser else None + try: async for res in result_generator: # We need to do it here, because if there are exceptions in @@ -228,6 +267,9 @@ async def chat_completion_stream_generator( # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request for i in range(num_choices): choice_data = ChatCompletionResponseStreamChoice( index=i, @@ -240,14 +282,18 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # if usage should be included if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): + # if continuous usage stats are requested, add it + if request.stream_options.continuous_usage_stats: prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens) chunk.usage = usage + # otherwise don't else: chunk.usage = None @@ -257,7 +303,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content = "" + last_msg_content: Optional[str] = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( "role") == role: @@ -298,6 +344,7 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: + i = output.index if finish_reason_sent[i]: @@ -320,20 +367,50 @@ async def chat_completion_stream_generator( logprobs = None delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + delta_message: Optional[DeltaMessage] = None - if request.tool_choice and type( - request.tool_choice - ) is ChatCompletionNamedToolChoiceParam: + # handle streaming deltas for tools with named tool_choice + if (request.tool_choice and type(request.tool_choice) is + ChatCompletionNamedToolChoiceParam): delta_message = DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( + DeltaToolCall(function=DeltaFunctionCall( name=request.tool_choice.function.name, - arguments=delta_text)) + arguments=delta_text), + index=i) ]) + + # handle streaming deltas for tools with "auto" tool choice + elif (self._should_stream_with_auto_tool_parsing(request) + and tool_parser): + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_texts[i], + current_text=output.text, + delta_text=delta_text, + previous_token_ids= \ + output.token_ids[ + :-1 * len(delta_token_ids) + ], + current_token_ids=output.token_ids, + delta_token_ids=delta_token_ids + ) + ) + + # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) + # set the previous values for the next iteration + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + if output.finish_reason is None: # Send token-by-token response for each request.n @@ -348,6 +425,8 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats): @@ -365,14 +444,55 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" + + # if the model is finished generating else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + if tool_parser: + index = len( + tool_parser.prev_tool_call_arr) - 1 if len( + tool_parser.prev_tool_call_arr) > 0 else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {})) + + # get what we've streamed so for for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) + ]) + # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason=output.finish_reason + if not (tool_parser + and len(tool_parser.prev_tool_call_arr)) + else "tool_calls", stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, @@ -398,6 +518,8 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" finish_reason_sent[i] = True + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): final_usage = UsageInfo( @@ -419,6 +541,7 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error + logger.error("error in chat completion stream generator: %s", e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished @@ -463,8 +586,21 @@ async def chat_completion_full_generator( else: logprobs = None - if request.tool_choice and type( + # by default, tools are not used. + tools_called = False + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if not (self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): + message = ChatMessage(role=role, content=output.text) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: + message = ChatMessage( role=role, content="", @@ -473,14 +609,47 @@ async def chat_completion_full_generator( name=request.tool_choice.function.name, arguments=output.text)) ]) + tools_called = True + + # if the request doesn't use tool choice + # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, content=output.text) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + tool_parser = self.tool_parser(tokenizer) + tool_call_info = tool_parser.extract_tool_calls(output.text) + tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage(role=role, content=output.text) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason="tool_calls" if tools_called else + output.finish_reason if output.finish_reason else "stop", stop_reason=output.stop_reason) choices.append(choice_data) @@ -488,10 +657,11 @@ async def chat_completion_full_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get("role") == role: - last_msg_content = conversation[-1]["content"] + last_msg_content = conversation[-1]["content"] or "" for choice in choices: - full_message = last_msg_content + choice.message.content + full_message = last_msg_content + (choice.message.content + or "") choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) @@ -574,3 +744,38 @@ def _create_chat_logprobs( )) return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + and output.finish_reason is not None + ) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c3c0d52072cd3..69a5ad5b62cfa 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -43,7 +43,11 @@ def __init__( request_logger=request_logger) # If this is None we use the tokenizer's default chat template - self.chat_template = load_chat_template(chat_template) + # the list of commonly-used chat template names for HF named templates + hf_chat_templates: List[str] = ['default', 'tool_use'] + self.chat_template = chat_template \ + if chat_template in hf_chat_templates \ + else load_chat_template(chat_template) async def create_tokenize( self, diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000000000..5d5d53784fedf --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,5 @@ +from .abstract_tool_parser import ToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .mistral_tool_parser import MistralToolParser + +__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"] \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000000000..b0807e6f1e782 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Sequence, Union + +from vllm.entrypoints.openai.protocol import (DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [] + + self.model_tokenizer = tokenizer + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000000000..7afbca7162edf --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,344 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + InitialDeltaToolCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_start_token] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_end_token] + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", + e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): + logger.debug("Generating text content! skipping tool parsing.") + if delta_text != self.tool_call_end_token: + return DeltaMessage(content=delta_text) + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], "") + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case - we haven't sent the initial delta with the tool call ID + # (it will be sent) + if not self.current_tool_initial_sent: + self.current_tool_initial_sent = True + return DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + elif not self.current_tool_name_sent: + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + args_delta_start_loc = cur_arguments_json.index(delta_text) \ + + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between\n%s", cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += argument_diff + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000000000..d48770c792e98 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,293 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + InitialDeltaToolCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + self.model_tokenizer = self.model_tokenizer.tokenizer + else: + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + + # use a regex to find the tool call. remove the BOT token + # and make sure to replace single quotes with double quotes + raw_tool_call = self.tool_call_regex.findall( + model_output.replace(self.bot_token, ""))[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", e) + print("ERROR", e) + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token_id not in current_token_ids: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool initial data incl. the id, type=function + # and idx not sent, send that + if not self.current_tool_initial_sent: + self.current_tool_initial_sent = True + delta = DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000000000..db7fc5259fc4e --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,87 @@ +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string, substring): + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index f9fcdead980a2..7161e83952a3d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -59,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if type(request) is CompletionRequest: return request - # user has chosen to not use any tool - if request.tool_choice == "none": + # user has chosen to not use any tool, + # OR is allowing the model to choose a tool. + if request.tool_choice == "none" or request.tool_choice == "auto": return request # user has chosen to use a named tool diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bfc658ef7d26b..e1f5b380120c5 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -8,8 +8,9 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, + CompletionRequest) from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( @@ -101,16 +102,30 @@ def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest, GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: + # if the request is a chat completion request, AND the tool choice is a + # named tool choice, do guided decoding + # using that tool as the JSON schema + if isinstance(request, ChatCompletionRequest) and isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam): + # Guided generation for tools/functions parameters + if request.tool_choice.type == "function": + for tool in request.tools: + if (tool.type == "function" and tool.function.name + == request.tool_choice.function.name): + json = json_dumps(tool.function.parameters, sort_keys=True) + return json, GuidedDecodingMode.JSON + return None, None - if request.guided_json: - json = request.guided_json - if isinstance(json, dict): + elif request.guided_json: + if isinstance(request.guided_json, dict): # turn dict into hashable string - json = json_dumps(json) - elif isinstance(json, BaseModel): + json = json_dumps(request.guided_json) + elif isinstance(request.guided_json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same - json = str(json.__signature__) + json = str(request.guided_json.__signature__) + else: + json = request.guided_json return json, GuidedDecodingMode.JSON elif request.guided_regex: return request.guided_regex, GuidedDecodingMode.REGEX