[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model #8405

Merged 25 commits into main on Oct 4, 2024

Commits (25)
87b6352 [add] add tools call for internlm2 (sydnash, Sep 12, 2024)
5355659 Merge branch 'main' into add-internlm2-for-tool-use (sydnash, Sep 12, 2024)
68cd89d [add] add some comments (sydnash, Sep 12, 2024)
d17f006 [add] add some comments (sydnash, Sep 12, 2024)
2d7d9d4 [fix] fix internlm2 tool chat template, fix the internlm2 tool call o… (sydnash, Sep 13, 2024)
12352e7 [add] add tool parser plugin doc (sydnash, Sep 13, 2024)
11bed0d [add] add tool parser plugin doc (sydnash, Sep 13, 2024)
8a8b840 [fix] fix the stream tool call for internlm2 (sydnash, Sep 13, 2024)
00c5da2 [fix] comment (sydnash, Sep 13, 2024)
882c764 [merge] resolve conflict (sydnash, Sep 13, 2024)
12b1035 [fix] use metavar to display the help info for --tool-call-parser, ad… (sydnash, Sep 14, 2024)
ed5b3fd [add] got valid tool parsers from ToolParserManager (sydnash, Sep 14, 2024)
ea2c089 [fix] fix build for docs (sydnash, Sep 14, 2024)
36ad5d0 [fix] internlm's tool call out may arguments or parameters (sydnash, Sep 15, 2024)
cf981c0 [merge] resolve conflict (sydnash, Sep 18, 2024)
647db0d refactor the tool parser to internlm, fix the test case of streamed_args (sydnash, Sep 26, 2024)
064ca1f merge main (sydnash, Sep 27, 2024)
106909c [fix] fix internlm parallel test, remove vllm/version.py (sydnash, Sep 28, 2024)
e242501 [format] (sydnash, Sep 29, 2024)
0a5ddf4 [format] (sydnash, Sep 29, 2024)
1db530d [fix] fix the mistral tool call error. recover vllm/version.py and de… (sydnash, Sep 29, 2024)
dc94a22 [fix] change vocab property to get_vocab method in mistral_tool_parse… (sydnash, Sep 29, 2024)
3048233 Merge remote-tracking branch 'origin/main' into add-internlm2-for-too… (sydnash, Sep 29, 2024)
a2f938f [fix] remove --tokenizer-mode mistral for mistral test. fix the syste… (sydnash, Oct 3, 2024)
4b619a2 [merge] merge from main (sydnash, Oct 3, 2024)
3 changes: 2 additions & 1 deletion docs/requirements-docs.txt
@@ -12,4 +12,5 @@ torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
74 changes: 72 additions & 2 deletions docs/source/serving/openai_compatible_server.md
@@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. Tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use: currently `hermes`, `mistral`, `llama3_json`, or `internlm`. Additional tool parsers
will continue to be added in the future, and you can also register your own tool parsers via `--tool-parser-plugin`.
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user-defined tool parsers into vLLM; the tool parser names registered by the plugin can then be passed to `--tool-call-parser`.
* `--chat-template` -- **optional** for auto tool choice. The path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
@@ -218,4 +219,73 @@ it works better with vLLM.

Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`

#### InternLM Models
Supported models:
* `internlm/internlm2_5-7b-chat` (confirmed)
* Additional InternLM2.5 function-calling models are compatible as well

Known issues:
* Although this implementation also supports InternLM2, the tool call results are not stable when tested with the `internlm/internlm2-chat-7b` model.

Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
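
For example, once a server is launched with the flags above, tool calls can be exercised with the official OpenAI Python client. This is a minimal sketch, not part of this PR: the endpoint URL, API key, and the `get_current_weather` tool schema are illustrative assumptions.

```python
# Minimal sketch: call a tool-enabled vLLM server with the OpenAI client.
# Assumes vLLM is serving internlm/internlm2_5-7b-chat on localhost:8000 with
# --enable-auto-tool-choice --tool-call-parser internlm and the template above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # hypothetical tool for illustration
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "the city to look up"
                }
            },
            "required": ["city"]
        }
    }
}]

response = client.chat.completions.create(
    model="internlm/internlm2_5-7b-chat",
    messages=[{"role": "user", "content": "What is the weather in Dallas?"}],
    tools=tools,
    tool_choice="auto",
)
# if the model decided to call the tool, the parsed call appears here
print(response.choices[0].message.tool_calls)
```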


### How to write a tool parser plugin

A tool parser plugin is a Python file containing one or more `ToolParser` implementations. You can write a `ToolParser` similar to the `Hermes2ProToolParser` in `vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py`.

Here is a summary of a plugin file:

```python
# import the required packages
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer


# define a tool parser and register it with vllm.
# the names listed in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["example"])
class ExampleToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    # adjust the request, e.g. set skip_special_tokens to False
    # so that tool call tokens reach the parser unchanged
    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        return request

    # implement tool call parsing for streaming calls
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        # return a DeltaMessage for the newly parsed fragment, or None
        # when the delta contains nothing tool-related
        return DeltaMessage(content=delta_text)

    # implement tool call parsing for non-streaming calls
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        # a real parser would extract tool calls from model_output here
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
```

Then you can use this plugin on the command line like this:
```
--enable-auto-tool-choice \
--tool-parser-plugin <absolute path of the plugin file> \
--tool-call-parser example \
--chat-template <your chat template> \
```
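
You can also load and inspect a plugin programmatically with the `ToolParserManager` APIs this PR adds; the sketch below mirrors what `api_server.py` does for `--tool-parser-plugin`. The plugin path and tokenizer name are illustrative assumptions.

```python
# Minimal sketch: import a plugin file and look up the parser it registers.
from transformers import AutoTokenizer

from vllm.entrypoints.openai.tool_parsers import ToolParserManager

# the same call api_server.py makes for --tool-parser-plugin
ToolParserManager.import_tool_parser("/abs/path/to/example_parser.py")

# names registered via @ToolParserManager.register_module are now listed
print(ToolParserManager.tool_parsers.keys())

# instantiate the parser the same way serving_chat.py does
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2_5-7b-chat",
                                          trust_remote_code=True)
parser = ToolParserManager.get_tool_parser("example")(tokenizer)
```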

60 changes: 60 additions & 0 deletions examples/tool_chat_template_internlm2_tool.jinja
@@ -0,0 +1,60 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}

{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}

{{- bos_token }}
{%- if system_message is defined %}
{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
{%- endif %}

{%- if tools is not none %}
{{- "<|im_start|>system name=<|plugin|>\n[" }}
{%- for tool in tools %}
{{- tool.function|tojson }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- endif %}

{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message.tool_calls is defined and message.tool_calls is not none %}
{%- set content = message["content"] if message["content"] else "" %}
{{- "<|im_start|>assistant\n" + content }}
{%- for tool_call in message.tool_calls %}
{%- set function=tool_call.function %}
{{- "<|action_start|><|plugin|>\n" }}
{{- '{"name": "' + function.name + '", '}}
{{- '"arguments": ' + function.arguments|tojson + '}' }}
{{- "<|action_end|>" }}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- elif message["role"] == "assistant" %}
{{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
{%- else %}
{{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
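
To sanity-check this template, you can render it offline with `transformers`. A minimal sketch, assuming a `transformers` version whose `apply_chat_template` accepts a `tools` argument; the tool definition is illustrative:

```python
# Minimal sketch: render the InternLM2 tool chat template offline to
# inspect the exact prompt string the server would build.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2_5-7b-chat",
                                          trust_remote_code=True)
with open("examples/tool_chat_template_internlm2_tool.jinja") as f:
    chat_template = f.read()

tools = [{
    "type": "function",
    "function": {  # hypothetical tool for illustration
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {"type": "object", "properties": {}}
    }
}]
messages = [{"role": "user", "content": "What is the weather in Dallas?"}]

prompt = tokenizer.apply_chat_template(messages,
                                       tools=tools,
                                       chat_template=chat_template,
                                       tokenize=False,
                                       add_generation_prompt=True)
# the output should contain the <|im_start|>system name=<|plugin|> block
print(prompt)
```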
14 changes: 13 additions & 1 deletion tests/tool_use/utils.py
@@ -87,6 +87,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
"arguments": [
"--tool-call-parser", "internlm", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_internlm2_tool.jinja"),
"--trust_remote_code"
],
"supports_parallel":
False,
}
}

@@ -109,7 +121,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
"type":
"string",
"description":
"the two-letter abbreviation for the state "
"must the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
},
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -53,6 +53,7 @@
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
@@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
    ToolParserManager.import_tool_parser(args.tool_parser_plugin)

valid_tool_parsers = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
        and args.tool_call_parser not in valid_tool_parsers:
    raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                   f"(choose from {{ {','.join(valid_tool_parsers)} }})")

temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))

14 changes: 13 additions & 1 deletion vllm/entrypoints/openai/cli_args.py
@@ -12,6 +12,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser


@@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")

valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes", "llama3_json"],
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")

parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Specify the tool parser plugin used to parse the model-generated tool "
"calls into OpenAI API format; the names registered in this plugin can "
"be used in --tool-call-parser.")

parser = AsyncEngineArgs.add_cli_args(parser)

parser.add_argument('--max-log-len',
38 changes: 20 additions & 18 deletions vllm/entrypoints/openai/serving_chat.py
@@ -29,10 +29,7 @@
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
Llama3JsonToolParser,
MistralToolParser,
ToolParser)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
@@ -82,15 +79,13 @@ def __init__(self,

self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
elif tool_parser == "llama3_json":
self.tool_parser = Llama3JsonToolParser
else:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e

async def create_chat_completion(
self,
@@ -187,6 +182,10 @@ async def create_chat_completion(
raw_request.state.request_metadata = request_metadata

try:
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)

if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
@@ -282,11 +281,11 @@ async def chat_completion_stream_generator(
num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices

num_prompt_tokens = 0

# one parser instance per choice so streaming state is not shared
tool_parsers: List[Optional[ToolParser]] = [
    self.tool_parser(tokenizer) if self.tool_parser else None
    for _ in range(num_choices)
]

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
@@ -324,7 +323,7 @@ async def chat_completion_stream_generator(
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):

tool_parser = tool_parsers[i]
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
@@ -399,6 +398,7 @@

for output in res.outputs:
i = output.index
tool_parser = tool_parsers[i]

if finish_reason_sent[i]:
continue
@@ -446,7 +446,8 @@
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids,
request=request))

# update the previous values for the next iteration
previous_texts[i] = current_text
@@ -685,7 +686,8 @@ async def chat_completion_full_generator(
and self.tool_parser:

tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
7 changes: 4 additions & 3 deletions vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -1,9 +1,10 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser

__all__ = [
"ToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Llama3JsonToolParser"
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
]