diff --git a/.changeset/seven-cows-rush.md b/.changeset/seven-cows-rush.md
new file mode 100644
index 000000000..57b3917ca
--- /dev/null
+++ b/.changeset/seven-cows-rush.md
@@ -0,0 +1,5 @@
+---
+"livekit-plugins-openai": patch
+---
+
+fix: openai llm retries
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
index 63d7a7981..0945ce941 100644
--- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
+++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
@@ -19,7 +19,7 @@
 import datetime
 import os
 from dataclasses import dataclass
-from typing import Any, Awaitable, Literal, MutableSet, Union
+from typing import Any, Literal, MutableSet, Union
 
 import aiohttp
 import httpx
@@ -640,50 +640,25 @@ def chat(
     ) -> "LLMStream":
         if parallel_tool_calls is None:
             parallel_tool_calls = self._opts.parallel_tool_calls
+
         if tool_choice is None:
             tool_choice = self._opts.tool_choice
-        opts: dict[str, Any] = dict()
-        if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
-            fncs_desc = []
-            for fnc in fnc_ctx.ai_functions.values():
-                fncs_desc.append(build_oai_function_description(fnc, self.capabilities))
-
-            opts["tools"] = fncs_desc
-
-            if fnc_ctx and parallel_tool_calls is not None:
-                opts["parallel_tool_calls"] = parallel_tool_calls
-            if tool_choice is not None:
-                if isinstance(tool_choice, ToolChoice):
-                    opts["tool_choice"] = {
-                        "type": "function",
-                        "function": {"name": tool_choice.name},
-                    }
-                else:
-                    opts["tool_choice"] = tool_choice
-
-        user = self._opts.user or openai.NOT_GIVEN
+
         if temperature is None:
             temperature = self._opts.temperature
 
-        messages = _build_oai_context(chat_ctx, id(self))
-
-        cmp = self._client.chat.completions.create(
-            messages=messages,
-            model=self._opts.model,
-            n=n,
-            temperature=temperature,
-            stream_options={"include_usage": True},
-            stream=True,
-            user=user,
-            **opts,
-        )
-
         return LLMStream(
             self,
-            oai_stream=cmp,
+            client=self._client,
+            model=self._opts.model,
+            user=self._opts.user,
             chat_ctx=chat_ctx,
             fnc_ctx=fnc_ctx,
            conn_options=conn_options,
+            n=n,
+            temperature=temperature,
+            parallel_tool_calls=parallel_tool_calls,
+            tool_choice=tool_choice,
         )
 
 
@@ -692,33 +667,80 @@ def __init__(
         self,
         llm: LLM,
         *,
-        oai_stream: Awaitable[openai.AsyncStream[ChatCompletionChunk]],
+        client: openai.AsyncClient,
+        model: str | ChatModels,
+        user: str | None,
         chat_ctx: llm.ChatContext,
-        fnc_ctx: llm.FunctionContext | None,
         conn_options: APIConnectOptions,
+        fnc_ctx: llm.FunctionContext | None,
+        temperature: float | None,
+        n: int | None,
+        parallel_tool_calls: bool | None,
+        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
     ) -> None:
         super().__init__(
             llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
         )
+        self._client = client
+        self._model = model
         self._llm: LLM = llm
-        self._awaitable_oai_stream = oai_stream
-        self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None
+
+        self._user = user
+        self._temperature = temperature
+        self._n = n
+        self._parallel_tool_calls = parallel_tool_calls
+        self._tool_choice = tool_choice
+
+    async def _run(self) -> None:
+        if hasattr(self._llm._client, "_refresh_credentials"):
+            await self._llm._client._refresh_credentials()
 
         # current function call that we're waiting for full completion (args are streamed)
+        # (defined inside the _run method to make sure the state is reset for each run/attempt)
+        self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None
         self._tool_call_id: str | None = None
         self._fnc_name: str | None = None
         self._fnc_raw_arguments: str | None = None
         self._tool_index: int | None = None
 
-    async def _run(self) -> None:
-        if hasattr(self._llm._client, "_refresh_credentials"):
-            await self._llm._client._refresh_credentials()
-
         try:
-            if not self._oai_stream:
-                self._oai_stream = await self._awaitable_oai_stream
+            opts: dict[str, Any] = dict()
+            if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
+                fncs_desc = []
+                for fnc in self._fnc_ctx.ai_functions.values():
+                    fncs_desc.append(
+                        build_oai_function_description(fnc, self._llm._capabilities)
+                    )
+
+                opts["tools"] = fncs_desc
+                if self._parallel_tool_calls is not None:
+                    opts["parallel_tool_calls"] = self._parallel_tool_calls
+
+            if self._tool_choice is not None:
+                if isinstance(self._tool_choice, ToolChoice):
+                    # specific function
+                    opts["tool_choice"] = {
+                        "type": "function",
+                        "function": {"name": self._tool_choice.name},
+                    }
+                else:
+                    opts["tool_choice"] = self._tool_choice
+
+            user = self._user or openai.NOT_GIVEN
+            messages = _build_oai_context(self._chat_ctx, id(self))
+
+            stream = await self._client.chat.completions.create(
+                messages=messages,
+                model=self._model,
+                n=self._n,
+                temperature=self._temperature,
+                stream_options={"include_usage": True},
+                stream=True,
+                user=user,
+                **opts,
+            )
 
-            async with self._oai_stream as stream:
+            async with stream:
                 async for chunk in stream:
                     for choice in chunk.choices:
                         chat_chunk = self._parse_choice(chunk.id, choice)
diff --git a/tests/test_llm.py b/tests/test_llm.py
index b1903868a..906e27462 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -5,7 +5,7 @@
 from typing import Annotated, Callable, Literal, Optional, Union
 
 import pytest
-from livekit.agents import llm
+from livekit.agents import APIConnectionError, llm
 from livekit.agents.llm import ChatContext, FunctionContext, TypeInfo, ai_callable
 from livekit.plugins import anthropic, openai
 
@@ -235,7 +235,7 @@ def change_volume(
     ) -> None: ...
 
     if not input_llm.capabilities.supports_choices_on_int:
-        with pytest.raises(ValueError, match="which is not supported by this model"):
+        with pytest.raises(APIConnectionError):
             stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
     else:
         stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
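
Note on the retry pattern the diff adopts. Previously LLMStream was handed a single pre-created awaitable for the OpenAI request, so after a failed attempt a retry could not issue a fresh request; now the stream stores only the request parameters and builds the chat.completions.create() call inside _run(), which is invoked once per attempt with per-attempt state reset. The sketch below is illustrative only: RetryingStream, request_factory, and collect() are hypothetical names, not part of livekit-plugins-openai.

    import asyncio
    from typing import AsyncIterator, Awaitable, Callable


    class RetryingStream:
        """Keep the request *parameters* (here, a factory), not an already-created request."""

        def __init__(
            self,
            request_factory: Callable[[], Awaitable[AsyncIterator[str]]],
            max_attempts: int = 3,
        ) -> None:
            # re-invoked on every attempt, so each retry issues a brand-new request
            self._request_factory = request_factory
            self._max_attempts = max_attempts

        async def collect(self) -> list[str]:
            last_exc: Exception | None = None
            for _ in range(self._max_attempts):
                chunks: list[str] = []  # per-attempt state, reset on every retry
                try:
                    stream = await self._request_factory()
                    async for chunk in stream:
                        chunks.append(chunk)
                    return chunks
                except Exception as exc:  # sketch only; real code narrows the exception types
                    last_exc = exc
            assert last_exc is not None
            raise last_exc


    async def _demo() -> None:
        attempts = 0

        async def fake_stream() -> AsyncIterator[str]:
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise ConnectionError("transient failure on the first attempt")
            for token in ("hel", "lo"):
                yield token

        async def factory() -> AsyncIterator[str]:
            return fake_stream()

        print(await RetryingStream(factory).collect())  # ['hel', 'lo'] on the second attempt


    if __name__ == "__main__":
        asyncio.run(_demo())

The key point is that the factory is re-invoked on every attempt, mirroring how _run() now rebuilds opts, messages, and the OpenAI stream each time instead of reusing the stream from a previous attempt.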