Skip to content

Commit

Permalink
fix(mistral): v3 function calling fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed May 28, 2024
1 parent 0bee873 commit e4d4d22
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions kani/prompts/impl/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _fmt_functions(functions: list[AIFunction]) -> str:
}
for f in functions
]
return f"[AVAILABLE_TOOLS]{tools_json}[/AVAILABLE_TOOLS]"
return f"[AVAILABLE_TOOLS] {tools_json}[/AVAILABLE_TOOLS]"


def fmt_available_tools(msg: ChatMessage, ctx: ApplyContext) -> ChatMessage:
Expand Down Expand Up @@ -68,11 +68,11 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
# generations).
.merge_consecutive(role=ChatRole.ASSISTANT, sep=" ")
# We wrap USER messages here since we do some shenanigans in the next step
.wrap(role=ChatRole.USER, prefix="[INST]", suffix="[/INST]")
.wrap(role=ChatRole.USER, prefix="[INST] ", suffix="[/INST]")
# --- function calling ---
.ensure_bound_function_calls()
.ensure_bound_function_calls(id_translator=lambda x: x.replace("-", "")[:9])
# Format function calls with the [TOOL_CALLS] format.
.function_call_fmt(json_tool_call, prefix="[TOOL_CALLS][", sep=",", suffix="]</s>")
.function_call_fmt(json_tool_call, prefix="[TOOL_CALLS] [", sep=",", suffix="]")
# Include the call ID in the FUNCTION result.
.apply(fmt_function_call_result, role=ChatRole.FUNCTION)
# Include the list of available functions just before the last user message
Expand All @@ -86,9 +86,9 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
.conversation_fmt(
prefix="<s>",
assistant_prefix=" ",
assistant_suffix=" </s>",
assistant_suffix="</s>",
assistant_suffix_if_last="",
function_prefix="[TOOL_RESULTS]",
function_prefix="[TOOL_RESULTS] ",
function_suffix="[/TOOL_RESULTS]",
)
)
Expand Down Expand Up @@ -124,7 +124,7 @@ class MixtralFunctionCallingAdapter(WrapperEngine):

@staticmethod
def _parse_tool_calls(content: str) -> tuple[str, list[ToolCall]]:
tool_json = re.search(r"\[TOOL_CALLS](.+)</s>", content, re.IGNORECASE | re.DOTALL)
tool_json = re.search(r"\[TOOL_CALLS]\s*(.+)</s>", content, re.IGNORECASE | re.DOTALL)
if tool_json is None:
return content, []
actions = json.loads(tool_json.group(1))
Expand Down

0 comments on commit e4d4d22

Please sign in to comment.