Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

voiceassistant: fix will_synthesize_assistant_reply race #638

Merged
merged 5 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fifty-dingos-call.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

voiceassistant: fix will_synthesize_assistant_reply race
137 changes: 137 additions & 0 deletions livekit-agents/livekit/agents/voice_assistant/speech_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from __future__ import annotations

import asyncio
from typing import AsyncIterable

from .. import utils
from ..llm import LLMStream
from .agent_output import SynthesisHandle


class SpeechHandle:
def __init__(
self,
*,
id: str,
allow_interruptions: bool,
add_to_chat_ctx: bool,
is_reply: bool,
user_question: str,
) -> None:
self._id = id
self._allow_interruptions = allow_interruptions
self._add_to_chat_ctx = add_to_chat_ctx

# is_reply is True when the speech is answering to a user question
self._is_reply = is_reply
self._user_question = user_question

self._init_fut: asyncio.Future[None] = asyncio.Future()
self._initialized = False

# source and synthesis_handle are None until the speech is initialized
self._source: str | LLMStream | AsyncIterable[str] | None = None
self._synthesis_handle: SynthesisHandle | None = None

@staticmethod
def create_assistant_reply(
*,
allow_interruptions: bool,
add_to_chat_ctx: bool,
user_question: str,
) -> SpeechHandle:
return SpeechHandle(
id=utils.shortuuid(),
allow_interruptions=allow_interruptions,
add_to_chat_ctx=add_to_chat_ctx,
is_reply=True,
user_question=user_question,
)

@staticmethod
def create_assistant_speech(
*,
allow_interruptions: bool,
add_to_chat_ctx: bool,
) -> SpeechHandle:
return SpeechHandle(
id=utils.shortuuid(),
allow_interruptions=allow_interruptions,
add_to_chat_ctx=add_to_chat_ctx,
is_reply=False,
user_question="",
)

async def wait_for_initialization(self) -> None:
await asyncio.shield(self._init_fut)

def initialize(
self,
*,
source: str | LLMStream | AsyncIterable[str],
synthesis_handle: SynthesisHandle,
) -> None:
if self.interrupted:
raise RuntimeError("speech is interrupted")

self._source = source
self._synthesis_handle = synthesis_handle
self._initialized = True
self._init_fut.set_result(None)

@property
def id(self) -> str:
return self._id

@property
def allow_interruptions(self) -> bool:
return self._allow_interruptions

@property
def add_to_chat_ctx(self) -> bool:
return self._add_to_chat_ctx

@property
def source(self) -> str | LLMStream | AsyncIterable[str]:
if self._source is None:
raise RuntimeError("speech not initialized")
return self._source

@property
def synthesis_handle(self) -> SynthesisHandle:
if self._synthesis_handle is None:
raise RuntimeError("speech not initialized")
return self._synthesis_handle

@synthesis_handle.setter
def synthesis_handle(self, synthesis_handle: SynthesisHandle) -> None:
"""synthesis handle can be replaced for the same speech.
This is useful when we need to do a new generation. (e.g for automatic function call answers)"""
if self._synthesis_handle is None:
raise RuntimeError("speech not initialized")

self._synthesis_handle = synthesis_handle

@property
def initialized(self) -> bool:
return self._initialized

@property
def is_reply(self) -> bool:
return self._is_reply

@property
def user_question(self) -> str:
return self._user_question

@property
def interrupted(self) -> bool:
return self._init_fut.cancelled() or (
self._synthesis_handle is not None and self._synthesis_handle.interrupted
)

def interrupt(self) -> None:
self._init_fut.cancel()

if self._synthesis_handle is not None:
self._synthesis_handle.interrupt()
Loading
Loading