fix: openai llm retries (#1196)
theomonnom authored Dec 11, 2024
1 parent 0e8a13d · commit db3b7a3
Showing 3 changed files with 75 additions and 48 deletions.
5 changes: 5 additions & 0 deletions .changeset/seven-cows-rush.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-openai": patch
---

fix: openai llm retries
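
This commit makes retries work by deferring the OpenAI request: `chat()` no longer calls `client.chat.completions.create(...)` eagerly and hands the awaitable to the stream; it only captures the call parameters, and `LLMStream._run()` builds a fresh streaming request on every run/attempt. A minimal, self-contained sketch of that pattern (generic names, not the actual livekit-agents classes) is below:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class _FakeResponse:
    text: str


class RetryingStream:
    """Sketch of the "build the request inside _run()" pattern."""

    def __init__(self, *, model: str, prompt: str, max_retries: int = 3) -> None:
        # Only plain parameters are stored; no request/awaitable is created yet.
        self._model = model
        self._prompt = prompt
        self._max_retries = max_retries
        self._attempts = 0

    async def _create_request(self) -> _FakeResponse:
        # Stand-in for client.chat.completions.create(...); fails on the first
        # attempt to show that a retry re-issues the request from scratch.
        self._attempts += 1
        if self._attempts == 1:
            raise ConnectionError("transient failure")
        return _FakeResponse(text=f"{self._model}: {self._prompt}")

    async def _run(self) -> _FakeResponse:
        # Per-attempt state would be reset here, then a new request is built.
        return await self._create_request()

    async def collect(self) -> _FakeResponse:
        last_exc: Exception | None = None
        for _ in range(self._max_retries):
            try:
                return await self._run()
            except ConnectionError as exc:  # retry only transient errors
                last_exc = exc
        raise last_exc or RuntimeError("unreachable")


async def main() -> None:
    stream = RetryingStream(model="gpt-4o-mini", prompt="hello")
    print((await stream.collect()).text)  # succeeds on the second attempt


asyncio.run(main())
```

Because the request is now created inside `_run()`, per-attempt state (`_oai_stream`, `_tool_call_id`, `_fnc_name`, ...) can also be reset there, which is exactly what the diff below does.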
114 changes: 68 additions & 46 deletions livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
@@ -19,7 +19,7 @@
import datetime
import os
from dataclasses import dataclass
from typing import Any, Awaitable, Literal, MutableSet, Union
from typing import Any, Literal, MutableSet, Union

import aiohttp
import httpx
@@ -640,50 +640,25 @@ def chat(
) -> "LLMStream":
if parallel_tool_calls is None:
parallel_tool_calls = self._opts.parallel_tool_calls

if tool_choice is None:
tool_choice = self._opts.tool_choice
opts: dict[str, Any] = dict()
if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
fncs_desc = []
for fnc in fnc_ctx.ai_functions.values():
fncs_desc.append(build_oai_function_description(fnc, self.capabilities))

opts["tools"] = fncs_desc

if fnc_ctx and parallel_tool_calls is not None:
opts["parallel_tool_calls"] = parallel_tool_calls
if tool_choice is not None:
if isinstance(tool_choice, ToolChoice):
opts["tool_choice"] = {
"type": "function",
"function": {"name": tool_choice.name},
}
else:
opts["tool_choice"] = tool_choice

user = self._opts.user or openai.NOT_GIVEN

if temperature is None:
temperature = self._opts.temperature

messages = _build_oai_context(chat_ctx, id(self))

cmp = self._client.chat.completions.create(
messages=messages,
model=self._opts.model,
n=n,
temperature=temperature,
stream_options={"include_usage": True},
stream=True,
user=user,
**opts,
)

return LLMStream(
self,
oai_stream=cmp,
client=self._client,
model=self._opts.model,
user=self._opts.user,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
conn_options=conn_options,
n=n,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)


@@ -692,33 +667,80 @@ def __init__(
self,
llm: LLM,
*,
oai_stream: Awaitable[openai.AsyncStream[ChatCompletionChunk]],
client: openai.AsyncClient,
model: str | ChatModels,
user: str | None,
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None,
conn_options: APIConnectOptions,
fnc_ctx: llm.FunctionContext | None,
temperature: float | None,
n: int | None,
parallel_tool_calls: bool | None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._client = client
self._model = model
self._llm: LLM = llm
self._awaitable_oai_stream = oai_stream
self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None

self._user = user
self._temperature = temperature
self._n = n
self._parallel_tool_calls = parallel_tool_calls
self._tool_choice = tool_choice

async def _run(self) -> None:
if hasattr(self._llm._client, "_refresh_credentials"):
await self._llm._client._refresh_credentials()

# current function call that we're waiting for full completion (args are streamed)
# (defined inside the _run method to make sure the state is reset for each run/attempt)
self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None
self._tool_call_id: str | None = None
self._fnc_name: str | None = None
self._fnc_raw_arguments: str | None = None
self._tool_index: int | None = None

async def _run(self) -> None:
if hasattr(self._llm._client, "_refresh_credentials"):
await self._llm._client._refresh_credentials()

try:
if not self._oai_stream:
self._oai_stream = await self._awaitable_oai_stream
opts: dict[str, Any] = dict()
if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
fncs_desc = []
for fnc in self._fnc_ctx.ai_functions.values():
fncs_desc.append(
build_oai_function_description(fnc, self._llm._capabilities)
)

opts["tools"] = fncs_desc
if self._parallel_tool_calls is not None:
opts["parallel_tool_calls"] = self._parallel_tool_calls

if self._tool_choice is not None:
if isinstance(self._tool_choice, ToolChoice):
# specific function
opts["tool_choice"] = {
"type": "function",
"function": {"name": self._tool_choice.name},
}
else:
opts["tool_choice"] = self._tool_choice

user = self._user or openai.NOT_GIVEN
messages = _build_oai_context(self._chat_ctx, id(self))

stream = await self._client.chat.completions.create(
messages=messages,
model=self._model,
n=self._n,
temperature=self._temperature,
stream_options={"include_usage": True},
stream=True,
user=user,
**opts,
)

async with self._oai_stream as stream:
async with stream:
async for chunk in stream:
for choice in chunk.choices:
chat_chunk = self._parse_choice(chunk.id, choice)
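
Condensed view of what the new `_run()` does with the real client: the streaming request is created inside the method, so every retry attempt issues a new API call, and the stream is consumed with `async with` / `async for`. This sketch assumes the official `openai` package and an `OPENAI_API_KEY` in the environment; the tool/function options built in the diff above are omitted for brevity:

```python
import asyncio

import openai


async def run_once(client: openai.AsyncClient, model: str, prompt: str) -> str:
    # Built fresh on every call, mirroring the per-attempt request creation
    # inside LLMStream._run() above.
    stream = await client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model,
        stream=True,
        stream_options={"include_usage": True},
    )

    text = ""
    async with stream:
        async for chunk in stream:
            for choice in chunk.choices:
                text += choice.delta.content or ""
    return text


async def main() -> None:
    client = openai.AsyncClient()
    print(await run_once(client, "gpt-4o-mini", "Say hi in one word."))


asyncio.run(main())
```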
4 changes: 2 additions & 2 deletions tests/test_llm.py
@@ -5,7 +5,7 @@
from typing import Annotated, Callable, Literal, Optional, Union

import pytest
from livekit.agents import llm
from livekit.agents import APIConnectionError, llm
from livekit.agents.llm import ChatContext, FunctionContext, TypeInfo, ai_callable
from livekit.plugins import anthropic, openai

@@ -235,7 +235,7 @@ def change_volume(
) -> None: ...

if not input_llm.capabilities.supports_choices_on_int:
with pytest.raises(ValueError, match="which is not supported by this model"):
with pytest.raises(APIConnectionError):
stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
else:
stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
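
The test change above follows from the same move: the request (including the OpenAI function descriptions) is now built inside `_run()`, so validation and request failures surface through the agents' retry path as `APIConnectionError` instead of a bare `ValueError`. A hypothetical usage sketch (assuming `livekit-agents` with `livekit-plugins-openai` installed and an OpenAI API key available):

```python
import asyncio

from livekit.agents import APIConnectionError
from livekit.agents.llm import ChatContext
from livekit.plugins import openai


async def main() -> None:
    llm = openai.LLM()  # reads OPENAI_API_KEY from the environment
    chat_ctx = ChatContext().append(role="user", text="Set the volume to 30")

    try:
        stream = llm.chat(chat_ctx=chat_ctx)
        async for _chunk in stream:
            pass  # consume the streamed chunks
    except APIConnectionError:
        # Raised once the retry attempts inside LLMStream._run() are exhausted.
        print("LLM request could not be completed")


asyncio.run(main())
```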
