ai functions improvements follow up #393

Merged · 3 commits · Jul 3, 2024
Changes from all commits
16 changes: 6 additions & 10 deletions examples/voice-assistant/minimal_assistant.py
@@ -4,24 +4,20 @@
from livekit.agents import JobContext, JobRequest, WorkerOptions, cli
from livekit.agents.llm import (
ChatContext,
ChatMessage,

Check failure (GitHub Actions / build, Ruff F401): examples/voice-assistant/minimal_assistant.py:7:5 `livekit.agents.llm.ChatMessage` imported but unused
ChatRole,

Check failure (GitHub Actions / build, Ruff F401): examples/voice-assistant/minimal_assistant.py:8:5 `livekit.agents.llm.ChatRole` imported but unused
)
from livekit.agents.voice_assistant import VoiceAssistant
from livekit.plugins import deepgram, openai, silero


async def entrypoint(ctx: JobContext):
initial_ctx = ChatContext(
messages=[
ChatMessage(
role=ChatRole.SYSTEM,
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
]
initial_ctx = ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)

assistant = VoiceAssistant(
…
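For context, a minimal sketch of how the new ChatContext API reads in caller code. That append() returns the context itself (and can therefore be chained), and that other role names such as "user" are accepted, are assumptions inferred from the assignment above, not something stated in this diff:

from livekit.agents.llm import ChatContext

# assumed fluent usage: each append() hands back the same ChatContext
initial_ctx = (
    ChatContext()
    .append(role="system", text="You are a voice assistant created by LiveKit.")
    .append(role="user", text="What's the weather like today?")  # role name assumed
)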
24 changes: 7 additions & 17 deletions livekit-agents/livekit/agents/llm/_oai_api.py
@@ -14,8 +14,8 @@

from __future__ import annotations

import asyncio

Check failure (GitHub Actions / build, Ruff F401): livekit-agents/livekit/agents/llm/_oai_api.py:17:8 `asyncio` imported but unused
import functools

Check failure (GitHub Actions / build, Ruff F401): livekit-agents/livekit/agents/llm/_oai_api.py:18:8 `functools` imported but unused
import inspect
import json
import typing
@@ -24,17 +24,17 @@
from . import function_context

__all__ = [
"create_ai_function_task",

Check failure (GitHub Actions / build, Ruff F822): livekit-agents/livekit/agents/llm/_oai_api.py:27:5 undefined name `create_ai_function_task` in `__all__`
"build_oai_function_description",
]


def create_ai_function_task(
def create_ai_function_info(
fnc_ctx: function_context.FunctionContext,
tool_call_id: str,
fnc_name: str,
raw_arguments: str, # JSON string
) -> tuple[asyncio.Task[Any], function_context.CalledFunction]:
) -> function_context.FunctionCallInfo:
if fnc_name not in fnc_ctx.ai_functions:
raise ValueError(f"AI function {fnc_name} not found")

@@ -80,21 +80,11 @@

sanitized_arguments[arg_info.name] = sanitized_value

func = functools.partial(fnc_info.callable, **sanitized_arguments)
if asyncio.iscoroutinefunction(fnc_info.callable):
task = asyncio.create_task(func())
else:
task = asyncio.create_task(asyncio.to_thread(func))

return (
task,
function_context.CalledFunction(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
task=task,
),
return function_context.FunctionCallInfo(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
)


…
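A hedged sketch of the call flow this refactor enables: create_ai_function_info() now only parses and validates the tool call, and execution is deferred to FunctionCallInfo.execute() (added in function_context.py below). The fnc_ctx argument, the "get_weather" function, the tool-call payload, and the import path of the private module are all illustrative assumptions:

import asyncio

from livekit.agents.llm._oai_api import create_ai_function_info  # path inferred from the file location above

async def handle_tool_call(fnc_ctx) -> None:
    # fnc_ctx is assumed to be a FunctionContext with a "get_weather" ai_callable registered
    call_info = create_ai_function_info(
        fnc_ctx,
        tool_call_id="call_123",
        fnc_name="get_weather",
        raw_arguments='{"location": "Tokyo"}',
    )
    called_fnc = call_info.execute()   # schedules the function as an asyncio.Task
    result = await called_fnc.task     # or inspect called_fnc.result / called_fnc.exception later
    print(result)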
15 changes: 8 additions & 7 deletions livekit-agents/livekit/agents/llm/chat_context.py
@@ -40,7 +40,7 @@ class ChatMessage:
role: ChatRole
name: str | None = None
content: str | list[str | ChatImage] | None = None
tool_calls: list[function_context.CalledFunction] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
tool_call_id: str | None = None

@staticmethod
@@ -50,20 +50,21 @@ def create_tool_from_called_function(
if not called_function.task.done():
raise ValueError("cannot create a tool result from a running ai function")

content = called_function.task.result()
if called_function.task.exception() is not None:
content = f"Error: {called_function.task.exception}"
try:
content = called_function.task.result()
except BaseException as e:
content = f"Error: {e}"

return ChatMessage(
role="tool",
name=called_function.function_info.name,
name=called_function.call_info.function_info.name,
content=content,
tool_call_id=called_function.tool_call_id,
tool_call_id=called_function.call_info.tool_call_id,
)

@staticmethod
def create_tool_calls(
called_functions: list[function_context.CalledFunction],
called_functions: list[function_context.FunctionCallInfo],
) -> "ChatMessage":
return ChatMessage(
role="assistant",
…
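To show how the renamed fields line up, a sketch of turning completed calls back into chat messages. It assumes ChatContext keeps a plain messages list (as the removed constructor call in minimal_assistant.py suggests); the helper name and its arguments are illustrative:

from livekit.agents.llm import ChatContext, ChatMessage

def record_tool_results(chat_ctx: ChatContext, called_fncs: list) -> None:
    # called_fncs: list[CalledFunction] whose tasks have already completed
    chat_ctx.messages.append(ChatMessage.create_tool_calls([c.call_info for c in called_fncs]))
    for called in called_fncs:
        # raises if the underlying task is still running
        chat_ctx.messages.append(ChatMessage.create_tool_from_called_function(called))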
29 changes: 28 additions & 1 deletion livekit-agents/livekit/agents/llm/function_context.py
@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import functools
import enum
import inspect
import typing
from dataclasses import dataclass, field
from typing import Any, Callable

from ..log import logger


class _UseDocMarker:

Check failure (GitHub Actions / build, Ruff I001): livekit-agents/livekit/agents/llm/function_context.py:15:1 import block is un-sorted or un-formatted
pass


@@ -57,12 +58,38 @@


@dataclass
class CalledFunction:
class FunctionCallInfo:
tool_call_id: str
function_info: FunctionInfo
raw_arguments: str
arguments: dict[str, Any]

def execute(self) -> CalledFunction:
function_info = self.function_info
func = functools.partial(function_info.callable, **self.arguments)
if asyncio.iscoroutinefunction(function_info.callable):
task = asyncio.create_task(func())
else:
task = asyncio.create_task(asyncio.to_thread(func))

called_fnc = CalledFunction(call_info=self, task=task)

def _on_done(fut):
try:
called_fnc.result = fut.result()
except BaseException as e:
called_fnc.exception = e

task.add_done_callback(_on_done)
return called_fnc


@dataclass
class CalledFunction:
call_info: FunctionCallInfo
task: asyncio.Task[Any]
result: Any | None = None
exception: BaseException | None = None


def ai_callable(
…
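A short sketch of the execute() contract added here; the run_tool helper is made up for illustration, and call_info stands in for a FunctionCallInfo obtained elsewhere (e.g. from create_ai_function_info above):

import asyncio

async def run_tool(call_info) -> None:       # call_info: FunctionCallInfo (placeholder)
    called = call_info.execute()              # returns immediately; the work runs in an asyncio.Task
    await asyncio.gather(called.task, return_exceptions=True)
    if called.exception is not None:          # filled in by the task's done callback
        print("ai function raised:", called.exception)
    else:
        print("ai function returned:", called.result)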
49 changes: 37 additions & 12 deletions livekit-agents/livekit/agents/llm/llm.py
@@ -1,18 +1,19 @@
from __future__ import annotations

import asyncio
import abc
from dataclasses import dataclass, field
from typing import AsyncIterator

from . import function_context
from .chat_context import ChatContext, ChatRole


@dataclass

Check failure (GitHub Actions / build, Ruff I001): livekit-agents/livekit/agents/llm/llm.py:1:1 import block is un-sorted or un-formatted
class ChoiceDelta:
role: ChatRole
content: str | None = None
tool_calls: list[function_context.CalledFunction] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None


@dataclass
@@ -28,7 +29,7 @@

class LLM(abc.ABC):
@abc.abstractmethod
async def chat(
def chat(
self,
*,
chat_ctx: ChatContext,
@@ -39,24 +40,48 @@


class LLMStream(abc.ABC):
def __init__(self) -> None:
self._called_functions: list[function_context.CalledFunction] = []
def __init__(
self, *, chat_ctx: ChatContext, fnc_ctx: function_context.FunctionContext | None
) -> None:
self._function_calls_info: list[function_context.FunctionCallInfo] = []
self._tasks = set[asyncio.Task]()
self._chat_ctx = chat_ctx
self._fnc_ctx = fnc_ctx

@property
def called_functions(self) -> list[function_context.CalledFunction]:
def function_calls(self) -> list[function_context.FunctionCallInfo]:
"""List of called functions from this stream."""
return self._called_functions
return self._function_calls_info

@abc.abstractmethod
async def gather_function_results(
@property
def chat_ctx(self) -> ChatContext:
"""The initial chat context of this stream."""
return self._chat_ctx

@property
def fnc_ctx(self) -> function_context.FunctionContext | None:
"""The function context of this stream."""
return self._fnc_ctx

def execute_functions(
self,
) -> list[function_context.CalledFunction]: ...
) -> list[function_context.CalledFunction]:
"""Run all functions in this stream."""
called_functions = []
for fnc_info in self._function_calls_info:
called_fnc = fnc_info.execute()
called_functions.append(called_fnc)

return called_functions

async def aclose(self) -> None:
for task in self._tasks:
task.cancel()

await asyncio.gather(*self._tasks, return_exceptions=True)

def __aiter__(self) -> AsyncIterator[ChatChunk]:
return self

@abc.abstractmethod
async def __anext__(self) -> ChatChunk: ...

@abc.abstractmethod
async def aclose(self) -> None: ...
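A hedged consumption sketch of the reshaped streaming interface; llm_plugin and chat_ctx are placeholders for a concrete LLM implementation and its context, and any additional chat() parameters elided above are omitted:

import asyncio

async def consume(llm_plugin, chat_ctx) -> None:
    stream = llm_plugin.chat(chat_ctx=chat_ctx)      # chat() is now synchronous and returns an LLMStream

    async for chunk in stream:                       # ChatChunk objects as they arrive
        ...                                          # forward deltas to TTS, transcripts, etc.

    if stream.function_calls:                        # FunctionCallInfo collected while streaming
        called = stream.execute_functions()          # list[CalledFunction], each wrapping an asyncio.Task
        await asyncio.gather(*(c.task for c in called), return_exceptions=True)

    await stream.aclose()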
12 changes: 6 additions & 6 deletions livekit-agents/livekit/agents/utils/event_emitter.py
@@ -5,15 +5,15 @@

class EventEmitter(Generic[T]):
def __init__(self) -> None:
self._events: Dict[T, Set[Callable[[Any], None]]] = dict()
self._events: Dict[T, Set[Callable[..., Any]]] = dict()

def emit(self, event: T, *args: Any, **kwargs: Any) -> None:
if event in self._events:
callables = self._events[event].copy()
for callback in callables:
callback(*args, **kwargs)

def once(self, event: T, callback: Optional[Callable[[Any], None]] = None):
def once(self, event: T, callback: Optional[Callable[..., Any]] = None):
if callback is not None:

def once_callback(*args: Any, **kwargs: Any):
@@ -23,26 +23,26 @@ def once_callback(*args: Any, **kwargs: Any):
return self.on(event, once_callback)
else:

def decorator(callback: Callable[[Any], None]):
def decorator(callback: Callable[..., Any]):
self.once(event, callback)
return callback

return decorator

def on(self, event: T, callback: Optional[Callable[[Any], None]] = None):
def on(self, event: T, callback: Optional[Callable[..., Any]] = None):
if callback is not None:
if event not in self._events:
self._events[event] = set()
self._events[event].add(callback)
return callback
else:

def decorator(callback: Callable[[Any], None]):
def decorator(callback: Callable[..., Any]):
self.on(event, callback)
return callback

return decorator

def off(self, event: T, callback: Callable[[Any], None]) -> None:
def off(self, event: T, callback: Callable[..., Any]) -> None:
if event in self._events:
self._events[event].remove(callback)
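With the callback type loosened to Callable[..., Any], handlers may take arbitrary positional and keyword arguments. A brief usage sketch; the event name is illustrative and the import path is inferred from the file location above:

from livekit.agents.utils.event_emitter import EventEmitter

emitter: EventEmitter[str] = EventEmitter()

@emitter.on("user_speech_committed")     # decorator form; registers and returns the callback
def on_speech(text: str, *, final: bool = False) -> None:
    print("speech:", text, "final:", final)

emitter.emit("user_speech_committed", "hello world", final=True)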
23 changes: 16 additions & 7 deletions livekit-agents/livekit/agents/vad.py
@@ -1,25 +1,26 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from enum import Enum, unique
from typing import AsyncIterator, List

from livekit import rtc


@unique
class VADEventType(Enum):
START_OF_SPEECH = 1
INFERENCE_DONE = 2
END_OF_SPEECH = 3
START_OF_SPEECH = "start_of_speech"
INFERENCE_DONE = "inference_done"
END_OF_SPEECH = "end_of_speech"


@dataclass
class VADEvent:
type: VADEventType
"""type of the event"""
samples_index: int
"""index of the samples of the event (when the event was fired)"""
duration: float = 0.0
"""duration of the speech in seconds (only for END_SPEAKING event)"""
"""index of the samples when the event was fired"""
duration: float
"""duration of the speech in seconds"""
frames: List[rtc.AudioFrame] = field(default_factory=list)
"""list of audio frames of the speech"""
probability: float = 0.0
@@ -31,6 +32,14 @@ class VADEvent:


class VAD(ABC):
def __init__(self, *, update_interval: float) -> None:
self._update_interval = update_interval

@property
def update_interval(self) -> float:
"""interval in seconds to update the VAD model"""
return self._update_interval

@abstractmethod
def stream(
self,
…
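Since VADEventType is now a string-valued enum and VAD carries an update_interval, event handling and subclassing look roughly like this. DummyVAD is purely illustrative, the stream() signature is elided in the diff above, and the import path is inferred from the file location:

from livekit.agents.vad import VAD, VADEvent, VADEventType

def on_vad_event(ev: VADEvent) -> None:
    if ev.type == VADEventType.START_OF_SPEECH:
        print("speech started at sample", ev.samples_index)
    elif ev.type == VADEventType.END_OF_SPEECH:
        print(f"speech ended after {ev.duration:.2f}s ({len(ev.frames)} frames)")

class DummyVAD(VAD):
    def __init__(self) -> None:
        super().__init__(update_interval=0.5)   # seconds between model updates

    def stream(self, *args, **kwargs):          # real signature elided in the diff above
        raise NotImplementedError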
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/voice_assistant/__init__.py
@@ -1,3 +1,3 @@
from .assistant import AssistantCallContext, VoiceAssistant
from .assistant import VoiceAssistant

__all__ = ["VoiceAssistant", "AssistantCallContext"]
__all__ = ["VoiceAssistant"]