Skip to content

Commit

Permalink
feat: use MistralToolCall when using mistral tool parser
Browse files Browse the repository at this point in the history
  • Loading branch information
gcalmettes committed Oct 2, 2024
1 parent 66985bd commit a8a57f4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
22 changes: 3 additions & 19 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from argparse import Namespace
from random import choices
from string import ascii_letters, digits
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionContentPartParam
Expand All @@ -23,8 +21,6 @@
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

ALPHANUMERIC = ascii_letters + digits

try:
from sphinx.ext.autodoc.mock import _MockModule

Expand Down Expand Up @@ -776,17 +772,6 @@ class ToolCall(OpenAIBaseModel):
function: FunctionCall


class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))


class DeltaFunctionCall(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
Expand All @@ -805,7 +790,7 @@ class ExtractedToolCallInformation(BaseModel):
tools_called: bool

# extracted tool calls
tool_calls: Sequence[Union[ToolCall, MistralToolCall]]
tool_calls: List[ToolCall]

# content - per OpenAI spec, content AND tool calls can be returned rarely
# But some models will do this intentionally
Expand All @@ -815,8 +800,7 @@ class ExtractedToolCallInformation(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
content: Optional[str] = None
tool_calls: Sequence[Union[ToolCall,
MistralToolCall]] = Field(default_factory=list)
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
Expand Down
20 changes: 18 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import json
import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field

from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, MistralToolCall)
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
Expand All @@ -19,6 +22,19 @@

logger = init_logger(__name__)

ALPHANUMERIC = ascii_letters + digits


class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))


class MistralToolParser(ToolParser):
"""
Expand Down Expand Up @@ -71,7 +87,7 @@ def extract_tool_calls(self,
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: Sequence[MistralToolCall] = [
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
Expand Down

0 comments on commit a8a57f4

Please sign in to comment.