Skip to content

Commit

Permalink
Allowing tool_calls in expected response from OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Oct 8, 2024
1 parent 098c050 commit 09b40bf
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/aviary/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,19 @@ async def __call__(
"""Run a completion that selects a tool in tools given the messages."""
completion_kwargs: dict[str, Any] = {}
# SEE: https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter
expected_finish_reason: str = "tool_calls"
expected_finish_reason: set[str] = {"tool_calls"}
if isinstance(tool_choice, Tool):
completion_kwargs["tool_choice"] = {
"type": "function",
"function": {"name": tool_choice.info.name},
}
expected_finish_reason = "stop"
expected_finish_reason = {"stop"} # TODO: should this be .add("stop") too?
elif tool_choice is not None:
completion_kwargs["tool_choice"] = tool_choice
if tool_choice == self.TOOL_CHOICE_REQUIRED:
expected_finish_reason = "stop"
# Even though docs say it should be just 'stop',
# in practice 'tool_calls' shows up too
expected_finish_reason.add("stop")

model_response = await self._bound_acompletion(
messages=MessagesAdapter.dump_python(
Expand All @@ -184,11 +186,11 @@ async def __call__(
f" choices, full response was {model_response}."
)
choice = model_response.choices[0]
if choice.finish_reason != expected_finish_reason:
if choice.finish_reason not in expected_finish_reason:
raise MalformedMessageError(
f"Expected finish reason {expected_finish_reason!r} in LiteLLM model"
f" response, got {choice.finish_reason!r}, full response was"
f" {model_response}."
f"Expected a finish reason in {expected_finish_reason} in LiteLLM"
f" model response, got finish reason {choice.finish_reason!r}, full"
f" response was {model_response} and tool choice was {tool_choice}."
)
usage = model_response.usage
return ToolRequestMessage(
Expand Down

0 comments on commit 09b40bf

Please sign in to comment.