forked from livekit/agents
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
voiceassistant: fix will_synthesize_assistant_reply race (livekit#638)
- Loading branch information
1 parent
7c707a9
commit 16a7ef1
Showing
3 changed files
with
245 additions
and
120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
"livekit-agents": patch | ||
--- | ||
|
||
voiceassistant: fix will_synthesize_assistant_reply race |
137 changes: 137 additions & 0 deletions
137
livekit-agents/livekit/agents/voice_assistant/speech_handle.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from typing import AsyncIterable | ||
|
||
from .. import utils | ||
from ..llm import LLMStream | ||
from .agent_output import SynthesisHandle | ||
|
||
|
||
class SpeechHandle:
    """Handle for one assistant utterance, from creation to playout.

    A handle is created first and initialized later: ``initialize`` attaches
    the text source and the synthesis handle and resolves an internal future
    that ``wait_for_initialization`` awaits. ``interrupt`` cancels that future
    and, once a synthesis handle is attached, interrupts it as well.
    """

    def __init__(
        self,
        *,
        id: str,
        allow_interruptions: bool,
        add_to_chat_ctx: bool,
        is_reply: bool,
        user_question: str,
    ) -> None:
        """Store the speech settings; the speech starts out uninitialized.

        Args:
            id: Identifier of this speech.
            allow_interruptions: Value exposed by ``allow_interruptions``.
            add_to_chat_ctx: Value exposed by ``add_to_chat_ctx``.
            is_reply: True when the speech answers a user question.
            user_question: The question being answered ("" when not a reply).
        """
        self._id = id
        self._allow_interruptions = allow_interruptions
        self._add_to_chat_ctx = add_to_chat_ctx

        # is_reply is True when the speech is answering a user question
        self._is_reply = is_reply
        self._user_question = user_question

        # Resolved by initialize(), cancelled by interrupt().
        # NOTE(review): asyncio.Future() binds to the current event loop, so
        # handles are presumably created from within the running loop.
        self._init_fut: asyncio.Future[None] = asyncio.Future()
        self._initialized = False

        # source and synthesis_handle are None until the speech is initialized
        self._source: str | LLMStream | AsyncIterable[str] | None = None
        self._synthesis_handle: SynthesisHandle | None = None

    @staticmethod
    def create_assistant_reply(
        *,
        allow_interruptions: bool,
        add_to_chat_ctx: bool,
        user_question: str,
    ) -> SpeechHandle:
        """Create a handle (with a fresh short uuid) for a reply to ``user_question``."""
        return SpeechHandle(
            id=utils.shortuuid(),
            allow_interruptions=allow_interruptions,
            add_to_chat_ctx=add_to_chat_ctx,
            is_reply=True,
            user_question=user_question,
        )

    @staticmethod
    def create_assistant_speech(
        *,
        allow_interruptions: bool,
        add_to_chat_ctx: bool,
    ) -> SpeechHandle:
        """Create a handle (with a fresh short uuid) for speech that is not a reply."""
        return SpeechHandle(
            id=utils.shortuuid(),
            allow_interruptions=allow_interruptions,
            add_to_chat_ctx=add_to_chat_ctx,
            is_reply=False,
            user_question="",
        )

    async def wait_for_initialization(self) -> None:
        """Wait until ``initialize`` has been called.

        The future is shielded so that cancelling the waiting task does not
        cancel the shared initialization future.

        Raises:
            asyncio.CancelledError: If the speech was interrupted before it
                was initialized (the future was cancelled), or if the waiting
                task itself is cancelled.
        """
        await asyncio.shield(self._init_fut)

    def initialize(
        self,
        *,
        source: str | LLMStream | AsyncIterable[str],
        synthesis_handle: SynthesisHandle,
    ) -> None:
        """Attach the source and synthesis handle and mark the speech initialized.

        Raises:
            RuntimeError: If the speech has already been interrupted.
            asyncio.InvalidStateError: If the speech was already initialized.
        """
        if self.interrupted:
            raise RuntimeError("speech is interrupted")

        # Check before mutating: set_result() on a resolved future raises
        # InvalidStateError, and raising only at that point would leave
        # _source / _synthesis_handle overwritten by the rejected call.
        if self._init_fut.done():
            raise asyncio.InvalidStateError("speech already initialized")

        self._source = source
        self._synthesis_handle = synthesis_handle
        self._initialized = True
        self._init_fut.set_result(None)

    @property
    def id(self) -> str:
        """Identifier of this speech."""
        return self._id

    @property
    def allow_interruptions(self) -> bool:
        """The ``allow_interruptions`` flag given at construction."""
        return self._allow_interruptions

    @property
    def add_to_chat_ctx(self) -> bool:
        """The ``add_to_chat_ctx`` flag given at construction."""
        return self._add_to_chat_ctx

    @property
    def source(self) -> str | LLMStream | AsyncIterable[str]:
        """Source passed to ``initialize``.

        Raises:
            RuntimeError: If the speech is not initialized yet.
        """
        if self._source is None:
            raise RuntimeError("speech not initialized")
        return self._source

    @property
    def synthesis_handle(self) -> SynthesisHandle:
        """Current synthesis handle.

        Raises:
            RuntimeError: If the speech is not initialized yet.
        """
        if self._synthesis_handle is None:
            raise RuntimeError("speech not initialized")
        return self._synthesis_handle

    @synthesis_handle.setter
    def synthesis_handle(self, synthesis_handle: SynthesisHandle) -> None:
        """Replace the synthesis handle of an already-initialized speech.

        The synthesis handle can be replaced for the same speech. This is
        useful when we need to do a new generation (e.g. for automatic
        function call answers).

        Raises:
            RuntimeError: If the speech is not initialized yet.
        """
        if self._synthesis_handle is None:
            raise RuntimeError("speech not initialized")

        self._synthesis_handle = synthesis_handle

    @property
    def initialized(self) -> bool:
        """True once ``initialize`` has completed."""
        return self._initialized

    @property
    def is_reply(self) -> bool:
        """True when the speech answers a user question."""
        return self._is_reply

    @property
    def user_question(self) -> str:
        """The user question given at construction."""
        return self._user_question

    @property
    def interrupted(self) -> bool:
        """Whether the speech was interrupted.

        True when the initialization future was cancelled (interrupted before
        initialization) or when the attached synthesis handle reports that it
        was interrupted.
        """
        return self._init_fut.cancelled() or (
            self._synthesis_handle is not None and self._synthesis_handle.interrupted
        )

    def interrupt(self) -> None:
        """Interrupt the speech.

        Cancels the initialization future (a no-op if it is already resolved)
        and interrupts the synthesis handle when one is attached.
        """
        self._init_fut.cancel()

        if self._synthesis_handle is not None:
            self._synthesis_handle.interrupt()
Oops, something went wrong.