Skip to content

Commit

Permalink
feat: mixtral tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed May 1, 2024
1 parent 34fc170 commit ea338df
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion kani/prompts/impl/mistral.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import abc
import json
import re

from kani.ai_function import AIFunction
from kani.models import ChatMessage, ChatRole, ToolCall
from kani.engines import BaseEngine, Completion
from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall
from kani.prompts import ApplyContext, PromptPipeline


Expand Down Expand Up @@ -91,6 +94,7 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
)
)


# tool use template
# {{bos_token}}
# {% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}
Expand All @@ -112,3 +116,68 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
# {{'[TOOL_CALLS]' + message['content']|string + eos_token}}
# {% endif %}
# {% endfor %}"


# ==== function call parsing ====
# [TOOL_CALLS][{'name': 'get_current_weather', 'arguments': {'location': 'Paris, France', 'format': 'celsius'}}]</s>
class MixtralFunctionCallingMixin(BaseEngine, abc.ABC):
"""Common Mixtral-8x22B function calling parsing mixin."""

@staticmethod
def _parse_tool_calls(content: str) -> tuple[str, list[ToolCall]]:
tool_json = re.search(r"\[TOOL_CALLS](.+)</s>", content, re.IGNORECASE | re.DOTALL)
if tool_json is None:
return content, []
actions = json.loads(tool_json.group(1))

# translate back to kani spec
tool_calls = []
for action in actions:
tool_name = action["name"]
tool_args = json.dumps(action["arguments"])
tool_id = action.get("id")
tool_call = ToolCall.from_function_call(FunctionCall(name=tool_name, arguments=tool_args), call_id_=tool_id)
tool_calls.append(tool_call)

# return trimmed content and tool calls
return content[: tool_json.start()], tool_calls

async def predict(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
completion = await super().predict(messages, functions, **hyperparams)

# if we have tools, parse
if functions:
completion.message.content, completion.message.tool_calls = self._parse_tool_calls(completion.message.text)

return completion

async def stream(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
content_parts = []
in_tool_call = False
inner_completion = None

# consume from the inner iterator, yielding as normal until we see a tool call or a completion
async for elem in super().stream(messages, functions, **hyperparams):
if isinstance(elem, str):
content_parts.append(elem)
# if we see the start of a tool call, stop yielding and start buffering
if elem == "[TOOL_CALLS]":
in_tool_call = True
# otherwise yield the string
if not in_tool_call:
yield elem
else:
# save the inner completion
inner_completion = elem

# we have consumed all the elements - construct a new completion
# if we don't have a tool call we can just yield the inner completion
if not in_tool_call and inner_completion:
yield inner_completion
# otherwise, parse tool calls from the content (preserving inner tool calls if necessary)
else:
content = "".join(content_parts)
content, tool_calls = self._parse_tool_calls(content)
if inner_completion:
tool_calls = (inner_completion.message.tool_calls or []) + tool_calls
yield Completion(ChatMessage.assistant(content, tool_calls=tool_calls))

0 comments on commit ea338df

Please sign in to comment.