From 0fb7f791daede54fa798f02b795c5bb35b9654ea Mon Sep 17 00:00:00 2001 From: Guillaume Calmettes Date: Wed, 2 Oct 2024 18:36:46 +0200 Subject: [PATCH] feat: use MistralToolCall when using mistral tool parser --- .../tool_parsers/mistral_tool_parser.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 4b0e1c91df97c..b61ad40a697e4 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,9 +1,12 @@ import json import re +from random import choices +from string import ascii_letters, digits from typing import Dict, List, Sequence, Union import partial_json_parser from partial_json_parser.core.options import Allow +from pydantic import Field from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -19,6 +22,19 @@ logger = init_logger(__name__) +ALPHANUMERIC = ascii_letters + digits + + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + class MistralToolParser(ToolParser): """ @@ -71,8 +87,8 @@ def extract_tool_calls(self, # load the JSON, and then use it to build the Function and # Tool Call function_call_arr = json.loads(raw_tool_call) - tool_calls: List[ToolCall] = [ - ToolCall( + tool_calls: List[MistralToolCall] = [ + MistralToolCall( type="function", function=FunctionCall( name=raw_function_call["name"],