Move create_ai_function_info to function_context.py (#1260)
jayeshp19 authored Dec 20, 2024
1 parent e32278b · commit c57b4cc
Showing 8 changed files with 116 additions and 194 deletions.
7 changes: 7 additions & 0 deletions .changeset/clever-lies-explode.md
@@ -0,0 +1,7 @@
---
"livekit-plugins-anthropic": patch
"livekit-plugins-openai": patch
"livekit-agents": patch
---

Moved create_ai_function_info to function_context.py for better reusability and to reduce repetition
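
With this change, the plugins import the shared helper from livekit-agents instead of carrying their own copies. For illustration, these are the two import forms that appear in the diffs below:

    from livekit.agents.llm import _create_ai_function_info
    from livekit.agents.llm.function_context import _create_ai_function_info

Either form resolves to the single implementation that now lives in function_context.py.
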
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/llm/__init__.py
@@ -15,6 +15,7 @@
FunctionContext,
FunctionInfo,
TypeInfo,
_create_ai_function_info,
ai_callable,
)
from .llm import (
@@ -54,4 +55,5 @@
"FallbackAdapter",
"AvailabilityChangedEvent",
"ToolChoice",
"_create_ai_function_info",
]
94 changes: 94 additions & 0 deletions livekit-agents/livekit/agents/llm/function_context.py
@@ -18,6 +18,7 @@
import enum
import functools
import inspect
import json
import types
import typing
from dataclasses import dataclass
@@ -303,3 +304,96 @@ def _is_optional_type(typ) -> Tuple[bool, Any]:
return True, non_none_args[0]

return False, None


def _create_ai_function_info(
fnc_ctx: FunctionContext,
tool_call_id: str,
fnc_name: str,
raw_arguments: str, # JSON string
) -> FunctionCallInfo:
if fnc_name not in fnc_ctx.ai_functions:
raise ValueError(f"AI function {fnc_name} not found")

parsed_arguments: dict[str, Any] = {}
try:
if raw_arguments: # ignore empty string
parsed_arguments = json.loads(raw_arguments)
except json.JSONDecodeError:
raise ValueError(
f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}"
)

fnc_info = fnc_ctx.ai_functions[fnc_name]

# Ensure all necessary arguments are present and of the correct type.
sanitized_arguments: dict[str, Any] = {}
for arg_info in fnc_info.arguments.values():
if arg_info.name not in parsed_arguments:
if arg_info.default is inspect.Parameter.empty:
raise ValueError(
f"AI function {fnc_name} missing required argument {arg_info.name}"
)
continue

arg_value = parsed_arguments[arg_info.name]
is_optional, inner_th = _is_optional_type(arg_info.type)

if typing.get_origin(inner_th) is not None:
if not isinstance(arg_value, list):
raise ValueError(
f"AI function {fnc_name} argument {arg_info.name} should be a list"
)

inner_type = typing.get_args(inner_th)[0]
sanitized_value = [
_sanitize_primitive(
value=v,
expected_type=inner_type,
choices=arg_info.choices,
)
for v in arg_value
]
else:
sanitized_value = _sanitize_primitive(
value=arg_value,
expected_type=inner_th,
choices=arg_info.choices,
)

sanitized_arguments[arg_info.name] = sanitized_value

return FunctionCallInfo(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
)


def _sanitize_primitive(
*, value: Any, expected_type: type, choices: tuple | None
) -> Any:
if expected_type is str:
if not isinstance(value, str):
raise ValueError(f"expected str, got {type(value)}")
elif expected_type in (int, float):
if not isinstance(value, (int, float)):
raise ValueError(f"expected number, got {type(value)}")

if expected_type is int:
if value % 1 != 0:
raise ValueError("expected int, got float")

value = int(value)
elif expected_type is float:
value = float(value)

elif expected_type is bool:
if not isinstance(value, bool):
raise ValueError(f"expected bool, got {type(value)}")

if choices and value not in choices:
raise ValueError(f"invalid value {value}, not in {choices}")

return value
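
For reference, a minimal usage sketch of the relocated helper, assuming the usual FunctionContext registration pattern; the get_weather function, its argument, and the tool-call values below are hypothetical:

    from typing import Annotated

    from livekit.agents import llm
    from livekit.agents.llm import _create_ai_function_info

    fnc_ctx = llm.FunctionContext()

    @fnc_ctx.ai_callable(description="Look up the current weather")
    async def get_weather(
        location: Annotated[str, llm.TypeInfo(description="City to look up")],
    ):
        ...

    # The provider returns tool-call arguments as a raw JSON string; the helper
    # checks them against the registered signature and returns a FunctionCallInfo
    # whose .arguments dict holds the sanitized, typed values.
    call_info = _create_ai_function_info(
        fnc_ctx,
        "tool_call_1",            # hypothetical tool_call_id
        "get_weather",
        '{"location": "Tokyo"}',
    )
    print(call_info.arguments)    # {'location': 'Tokyo'}

The same call now backs both the Anthropic and OpenAI plugins, as the remaining diffs show.
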
livekit-plugins-anthropic: llm.py
@@ -24,7 +24,6 @@
Awaitable,
List,
Literal,
Tuple,
Union,
cast,
get_args,
@@ -41,7 +40,10 @@
utils,
)
from livekit.agents.llm import ToolChoice
from livekit.agents.llm.function_context import _is_optional_type
from livekit.agents.llm.function_context import (
_create_ai_function_info,
_is_optional_type,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

import anthropic
@@ -487,67 +489,6 @@ def _build_anthropic_image_content(
)


def _create_ai_function_info(
fnc_ctx: llm.function_context.FunctionContext,
tool_call_id: str,
fnc_name: str,
raw_arguments: str, # JSON string
) -> llm.function_context.FunctionCallInfo:
if fnc_name not in fnc_ctx.ai_functions:
raise ValueError(f"AI function {fnc_name} not found")

parsed_arguments: dict[str, Any] = {}
try:
if raw_arguments: # ignore empty string
parsed_arguments = json.loads(raw_arguments)
except json.JSONDecodeError:
raise ValueError(
f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}"
)

fnc_info = fnc_ctx.ai_functions[fnc_name]

# Ensure all necessary arguments are present and of the correct type.
sanitized_arguments: dict[str, Any] = {}
for arg_info in fnc_info.arguments.values():
if arg_info.name not in parsed_arguments:
if arg_info.default is inspect.Parameter.empty:
raise ValueError(
f"AI function {fnc_name} missing required argument {arg_info.name}"
)
continue

arg_value = parsed_arguments[arg_info.name]
is_optional, inner_th = _is_optional_type(arg_info.type)

if get_origin(inner_th) is not None:
if not isinstance(arg_value, list):
raise ValueError(
f"AI function {fnc_name} argument {arg_info.name} should be a list"
)

inner_type = get_args(inner_th)[0]
sanitized_value = [
_sanitize_primitive(
value=v, expected_type=inner_type, choices=arg_info.choices
)
for v in arg_value
]
else:
sanitized_value = _sanitize_primitive(
value=arg_value, expected_type=inner_th, choices=arg_info.choices
)

sanitized_arguments[arg_info.name] = sanitized_value

return llm.function_context.FunctionCallInfo(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
)


def _build_function_description(
fnc_info: llm.function_context.FunctionInfo,
) -> anthropic.types.ToolParam:
@@ -598,31 +539,3 @@ def type2str(t: type) -> str:
"description": fnc_info.description,
"input_schema": input_schema,
}


def _sanitize_primitive(
*, value: Any, expected_type: type, choices: Tuple[Any] | None
) -> Any:
if expected_type is str:
if not isinstance(value, str):
raise ValueError(f"expected str, got {type(value)}")
elif expected_type in (int, float):
if not isinstance(value, (int, float)):
raise ValueError(f"expected number, got {type(value)}")

if expected_type is int:
if value % 1 != 0:
raise ValueError("expected int, got float")

value = int(value)
elif expected_type is float:
value = float(value)

elif expected_type is bool:
if not isinstance(value, bool):
raise ValueError(f"expected bool, got {type(value)}")

if choices and value not in choices:
raise ValueError(f"invalid value {value}, not in {choices}")

return value
livekit-plugins-assemblyai: stt.py
@@ -289,6 +289,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse):
except Exception:
logger.exception("failed to process AssemblyAI message")

ws: aiohttp.ClientWebSocketResponse | None = None

while True:
try:
ws = await self._connect_ws()
livekit-plugins-openai: _oai_api.py
@@ -15,79 +15,13 @@
from __future__ import annotations

import inspect
import json
import typing
from typing import Any

from livekit.agents.llm import function_context, llm
from livekit.agents.llm.function_context import _is_optional_type

__all__ = ["build_oai_function_description", "create_ai_function_info"]


def create_ai_function_info(
fnc_ctx: function_context.FunctionContext,
tool_call_id: str,
fnc_name: str,
raw_arguments: str, # JSON string
) -> function_context.FunctionCallInfo:
if fnc_name not in fnc_ctx.ai_functions:
raise ValueError(f"AI function {fnc_name} not found")

parsed_arguments: dict[str, Any] = {}
try:
if raw_arguments: # ignore empty string
parsed_arguments = json.loads(raw_arguments)
except json.JSONDecodeError:
raise ValueError(
f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}"
)

fnc_info = fnc_ctx.ai_functions[fnc_name]

# Ensure all necessary arguments are present and of the correct type.
sanitized_arguments: dict[str, Any] = {}
for arg_info in fnc_info.arguments.values():
if arg_info.name not in parsed_arguments:
if arg_info.default is inspect.Parameter.empty:
raise ValueError(
f"AI function {fnc_name} missing required argument {arg_info.name}"
)
continue

arg_value = parsed_arguments[arg_info.name]
is_optional, inner_th = _is_optional_type(arg_info.type)

if typing.get_origin(inner_th) is not None:
if not isinstance(arg_value, list):
raise ValueError(
f"AI function {fnc_name} argument {arg_info.name} should be a list"
)

inner_type = typing.get_args(inner_th)[0]
sanitized_value = [
_sanitize_primitive(
value=v,
expected_type=inner_type,
choices=arg_info.choices,
)
for v in arg_value
]
else:
sanitized_value = _sanitize_primitive(
value=arg_value,
expected_type=inner_th,
choices=arg_info.choices,
)

sanitized_arguments[arg_info.name] = sanitized_value

return function_context.FunctionCallInfo(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
)
__all__ = ["build_oai_function_description"]


def build_oai_function_description(
@@ -156,31 +90,3 @@ def type2str(t: type) -> str:
},
},
}


def _sanitize_primitive(
*, value: Any, expected_type: type, choices: tuple | None
) -> Any:
if expected_type is str:
if not isinstance(value, str):
raise ValueError(f"expected str, got {type(value)}")
elif expected_type in (int, float):
if not isinstance(value, (int, float)):
raise ValueError(f"expected number, got {type(value)}")

if expected_type is int:
if value % 1 != 0:
raise ValueError("expected int, got float")

value = int(value)
elif expected_type is float:
value = float(value)

elif expected_type is bool:
if not isinstance(value, bool):
raise ValueError(f"expected bool, got {type(value)}")

if choices and value not in choices:
raise ValueError(f"invalid value {value}, not in {choices}")

return value
livekit-plugins-openai: llm.py
@@ -29,17 +29,14 @@
APITimeoutError,
llm,
)
from livekit.agents.llm import ToolChoice
from livekit.agents.llm import ToolChoice, _create_ai_function_info
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

import openai
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from openai.types.chat.chat_completion_chunk import Choice

from ._oai_api import (
build_oai_function_description,
create_ai_function_info,
)
from ._oai_api import build_oai_function_description
from .log import logger
from .models import (
CerebrasChatModels,
@@ -840,7 +837,7 @@ def _try_build_function(self, id: str, choice: Choice) -> llm.ChatChunk | None:
)
return None

fnc_info = create_ai_function_info(
fnc_info = _create_ai_function_info(
self._fnc_ctx, self._tool_call_id, self._fnc_name, self._fnc_raw_arguments
)


0 comments on commit c57b4cc
