Skip to content

Commit

Permalink
feat: use MistralToolCall when using mistral tool parser
Browse files Browse the repository at this point in the history
  • Loading branch information
gcalmettes committed Oct 2, 2024
1 parent 563649a commit 0fb7f79
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field

from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -19,6 +22,19 @@

logger = init_logger(__name__)

ALPHANUMERIC = ascii_letters + digits


class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))


class MistralToolParser(ToolParser):
"""
Expand Down Expand Up @@ -71,8 +87,8 @@ def extract_tool_calls(self,
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[ToolCall] = [
ToolCall(
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
Expand Down

0 comments on commit 0fb7f79

Please sign in to comment.