Skip to content

Commit

Permalink
fix: NaiveJSONToolCallParser streaming, kwarg passes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Jan 25, 2025
1 parent 738db13 commit f53c1fb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
4 changes: 3 additions & 1 deletion kani/tool_parsers/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def __init__(
tool_call_end_token: str = "<|tool▁outputs▁end|>",
**kwargs,
):
super().__init__(*args, tool_call_start_token, tool_call_end_token, **kwargs)
super().__init__(
*args, tool_call_start_token=tool_call_start_token, tool_call_end_token=tool_call_end_token, **kwargs
)

def parse_tool_calls(self, content: str):
tool_content_match = re.search(
Expand Down
52 changes: 51 additions & 1 deletion kani/tool_parsers/json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

from kani.models import FunctionCall, ToolCall
from kani.engines import Completion
from kani.models import ChatMessage, FunctionCall, ToolCall
from .base import BaseToolCallParser


Expand All @@ -20,6 +21,9 @@ class NaiveJSONToolCallParser(BaseToolCallParser):
then assume it is a function call. Otherwise, return the content unchanged.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, tool_call_start_token="", tool_call_end_token="", **kwargs)

def parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
"""Given the string completion of the model, return the content without tool calls and the parsed tool calls."""
try:
Expand All @@ -31,3 +35,49 @@ def parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
except json.JSONDecodeError:
return content, []
return content, []

async def stream(self, messages, functions=None, **hyperparams):
# special case - if we see a { at start of message, defer until end of message to see if it's a function call
# otherwise stream as normal
content_parts = []
seen_non_tool_call_token = False
in_tool_call = False
inner_completion = None

# consume from the inner iterator, yielding as normal until we see a tool call or a completion
async for elem in self.engine.stream(messages, functions, **hyperparams):
if isinstance(elem, str):
content_parts.append(elem)
# if we see {, stop yielding and start buffering
if elem.startswith("{") and not seen_non_tool_call_token:
in_tool_call = True
# otherwise yield the string
if not in_tool_call:
seen_non_tool_call_token = True
yield elem
else:
# save the inner completion
inner_completion = elem

# we have consumed all the elements - construct a new completion
# if we don't have a tool call we can just yield the inner completion
if not in_tool_call and inner_completion:
yield inner_completion
# otherwise, parse tool calls from the content (preserving inner tool calls if necessary)
else:
content = "".join(content_parts)
# noinspection DuplicatedCode
content, tool_calls = self.parse_tool_calls(content)
if inner_completion:
tool_calls = (inner_completion.message.tool_calls or []) + tool_calls
prompt_tokens = inner_completion.prompt_tokens
completion_tokens = inner_completion.completion_tokens
else:
prompt_tokens = None
completion_tokens = None
clean_content = content.strip()
yield Completion(
ChatMessage.assistant(clean_content, tool_calls=tool_calls),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
4 changes: 3 additions & 1 deletion kani/tool_parsers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class MistralToolCallParser(BaseToolCallParser):
"""

def __init__(self, *args, tool_call_start_token: str = "[TOOL_CALLS]", tool_call_end_token: str = "</s>", **kwargs):
super().__init__(*args, tool_call_start_token, tool_call_end_token, **kwargs)
super().__init__(
*args, tool_call_start_token=tool_call_start_token, tool_call_end_token=tool_call_end_token, **kwargs
)

def parse_tool_calls(self, content: str):
tool_json = re.search(
Expand Down

0 comments on commit f53c1fb

Please sign in to comment.