From 6d7e9707d1f3bcd7c476207dbf1d28a9fcee2e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 3 Sep 2024 23:40:12 -0700 Subject: [PATCH 01/13] wip --- .../agents/voice_assistant/agent_output.py | 16 +++++++++---- .../agents/voice_assistant/voice_assistant.py | 23 ++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/livekit-agents/livekit/agents/voice_assistant/agent_output.py b/livekit-agents/livekit/agents/voice_assistant/agent_output.py index ec50fde98..31ffbf89e 100644 --- a/livekit-agents/livekit/agents/voice_assistant/agent_output.py +++ b/livekit-agents/livekit/agents/voice_assistant/agent_output.py @@ -1,8 +1,9 @@ from __future__ import annotations import asyncio +import inspect import time -from typing import Any, AsyncIterable, Callable, Union +from typing import Any, AsyncIterable, Awaitable, Callable, Union from livekit import rtc @@ -12,7 +13,7 @@ from .agent_playout import AgentPlayout, PlayoutHandle from .log import logger -SpeechSource = Union[AsyncIterable[str], str] +SpeechSource = Union[AsyncIterable[str], str, Awaitable[str]] class SynthesisHandle: @@ -155,10 +156,15 @@ def _will_forward_transcription( @utils.log_exceptions(logger=logger) async def _synthesize_task(self, handle: SynthesisHandle) -> None: """Synthesize speech from the source""" - if isinstance(handle._speech_source, str): - co = _str_synthesis_task(handle._speech_source, handle) + speech_source = handle._speech_source + + if isinstance(speech_source, Awaitable): + speech_source = await speech_source + co = _str_synthesis_task(speech_source, handle) + elif isinstance(speech_source, str): + co = _str_synthesis_task(speech_source, handle) else: - co = _stream_synthesis_task(handle._speech_source, handle) + co = _stream_synthesis_task(speech_source, handle) synth = asyncio.create_task(co) synth.add_done_callback(lambda _: handle._buf_ch.close()) diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index 248673c03..760f1b170 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -22,6 +22,13 @@ Union[Optional[LLMStream], Awaitable[Optional[LLMStream]]], ] + +WillSynthesizeAssistantSpeech = Callable[ + ["VoiceAssistant", Union[str, AsyncIterable[str]]], + Union[str, AsyncIterable[str], Awaitable[str]], +] + + EventTypes = Literal[ "user_started_speaking", "user_stopped_speaking", @@ -72,6 +79,12 @@ def _default_will_synthesize_assistant_reply( return assistant.llm.chat(chat_ctx=chat_ctx, fnc_ctx=assistant.fnc_ctx) +def _default_will_synthesize_assistant_speech( + assistant: VoiceAssistant, text: str | AsyncIterable[str] +) -> str | AsyncIterable[str]: + return text + + @dataclass(frozen=True) class _ImplOptions: allow_interruptions: bool @@ -79,6 +92,7 @@ class _ImplOptions: int_min_words: int preemptive_synthesis: bool will_synthesize_assistant_reply: WillSynthesizeAssistantReply + will_synthesize_assistant_speech: WillSynthesizeAssistantSpeech plotting: bool transcription: AssistantTranscriptionOptions @@ -124,6 +138,7 @@ def __init__( preemptive_synthesis: bool = True, transcription: AssistantTranscriptionOptions = AssistantTranscriptionOptions(), will_synthesize_assistant_reply: WillSynthesizeAssistantReply = _default_will_synthesize_assistant_reply, + will_synthesize_assistant_speech: WillSynthesizeAssistantSpeech = _default_will_synthesize_assistant_speech, plotting: bool = False, loop: asyncio.AbstractEventLoop | None = None, ) -> None: @@ -145,6 +160,9 @@ def __init__( transcription: Options for assistant transcription. will_synthesize_assistant_reply: Callback called when the assistant is about to synthesize a reply. This can be used to customize the reply (e.g: inject context/RAG). + will_synthesize_assistant_speech: Callback called when the assistant is about to + synthesize a speech. This can be used to customize text before the speech synthesis. + (e.g: editing the pronunciation of a word). plotting: Whether to enable plotting for debugging. matplotlib must be installed. loop: Event loop to use. Default to asyncio.get_event_loop(). """ @@ -158,6 +176,7 @@ def __init__( preemptive_synthesis=preemptive_synthesis, transcription=transcription, will_synthesize_assistant_reply=will_synthesize_assistant_reply, + will_synthesize_assistant_speech=will_synthesize_assistant_speech, ) self._plotter = AssistantPlotter(self._loop) @@ -712,9 +731,11 @@ def _synthesize_agent_speech( if isinstance(source, LLMStream): source = _llm_stream_to_str_iterable(speech_id, source) + speech_source = self._opts.will_synthesize_assistant_speech(self, source) + return self._agent_output.synthesize( speech_id=speech_id, - transcript=source, + transcript=speech_source, transcription=self._opts.transcription.agent_transcription, transcription_speed=self._opts.transcription.agent_transcription_speed, sentence_tokenizer=self._opts.transcription.sentence_tokenizer, From 13bfa071e9b00097c144b25574a3c98d88e677ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Tue, 3 Sep 2024 23:41:46 -0700 Subject: [PATCH 02/13] Update agent_output.py --- livekit-agents/livekit/agents/voice_assistant/agent_output.py | 1 - 1 file changed, 1 deletion(-) diff --git a/livekit-agents/livekit/agents/voice_assistant/agent_output.py b/livekit-agents/livekit/agents/voice_assistant/agent_output.py index 31ffbf89e..0317a2a82 100644 --- a/livekit-agents/livekit/agents/voice_assistant/agent_output.py +++ b/livekit-agents/livekit/agents/voice_assistant/agent_output.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import inspect import time from typing import Any, AsyncIterable, Awaitable, Callable, Union From 99668fe78ed19a1eaa3c845a3e27b9617ee1dedd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Wed, 4 Sep 2024 10:46:04 -0700 Subject: [PATCH 03/13] Create three-onions-destroy.md --- .changeset/three-onions-destroy.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/three-onions-destroy.md diff --git a/.changeset/three-onions-destroy.md b/.changeset/three-onions-destroy.md new file mode 100644 index 000000000..96634b947 --- /dev/null +++ b/.changeset/three-onions-destroy.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +voiceassistant: add will_synthesize_assistant_speech From 7cbf6153a6979ced91e8a05223c1e42387af95ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 4 Sep 2024 12:54:42 -0700 Subject: [PATCH 04/13] wip --- .../voice-assistant/custom_pronunciation.py | 52 ++++++++++++++++++ .../livekit/agents/utils/__init__.py | 3 +- livekit-agents/livekit/agents/utils/misc.py | 55 ++++++++++++++++++- tests/test_tokenizer.py | 31 ++++++++++- 4 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 examples/voice-assistant/custom_pronunciation.py diff --git a/examples/voice-assistant/custom_pronunciation.py b/examples/voice-assistant/custom_pronunciation.py new file mode 100644 index 000000000..15158a4dc --- /dev/null +++ b/examples/voice-assistant/custom_pronunciation.py @@ -0,0 +1,52 @@ +from __future__ import annotations +from typing import AsyncIterable +import asyncio + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, utils +from livekit.agents.voice_assistant import VoiceAssistant +from livekit.plugins import deepgram, cartesia, openai, silero + +load_dotenv() + + +async def entrypoint(ctx: JobContext): + initial_ctx = llm.ChatContext().append( + role="system", + text=( + "You are a voice assistant created by LiveKit. Your interface with users will be voice. " + "You should use short and concise responses, and avoiding usage of unpronouncable punctuation." + ), + ) + + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + async def _will_synthesize_assistant_speech( + assistant: VoiceAssistant, text: str | AsyncIterable[str] + ): + # Cartesia TTS is incorrectly pronouncing "LiveKit", so we'll replace it with a phonetic + # spelling + return utils.replace_words( + text=text, replacements={"LiveKit": "<>Kit"} + ) + + # also for this example, we also intensify the keyword "LiveKit" to make it more likely to be + # recognized with the STT + deepgram_stt = deepgram.STT(keywords=[("LiveKit", 2.0)]) + + assistant = VoiceAssistant( + vad=silero.VAD.load(), + stt=deepgram_stt, + llm=openai.LLM(), + tts=cartesia.TTS(), + chat_ctx=initial_ctx, + will_synthesize_assistant_speech=_will_synthesize_assistant_speech, + ) + assistant.start(ctx.room) + + await assistant.say("Hey, how can I help you today?", allow_interruptions=True) + + +if __name__ == "__main__": + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint)) diff --git a/livekit-agents/livekit/agents/utils/__init__.py b/livekit-agents/livekit/agents/utils/__init__.py index 361308572..c139cf0c9 100644 --- a/livekit-agents/livekit/agents/utils/__init__.py +++ b/livekit-agents/livekit/agents/utils/__init__.py @@ -2,7 +2,7 @@ from .event_emitter import EventEmitter from .exp_filter import ExpFilter from .log import log_exceptions -from .misc import AudioBuffer, merge_frames, shortuuid, time_ms +from .misc import AudioBuffer, merge_frames, shortuuid, time_ms, replace_words from .moving_average import MovingAverage __all__ = [ @@ -10,6 +10,7 @@ "merge_frames", "time_ms", "shortuuid", + "replace_words", "http_context", "ExpFilter", "MovingAverage", diff --git a/livekit-agents/livekit/agents/utils/misc.py b/livekit-agents/livekit/agents/utils/misc.py index 7720a53db..a926e39f1 100644 --- a/livekit-agents/livekit/agents/utils/misc.py +++ b/livekit-agents/livekit/agents/utils/misc.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import time import uuid -from typing import List, Union +from typing import AsyncIterable, List, Union, overload from livekit import rtc @@ -48,3 +50,54 @@ def time_ms() -> int: def shortuuid() -> str: return str(uuid.uuid4().hex)[:12] + + +@overload +def replace_words( + *, + text: str, + replacements: dict[str, str], +) -> str: ... + + +@overload +def replace_words( + *, + text: AsyncIterable[str], + replacements: dict[str, str], +) -> AsyncIterable[str]: ... + + +def replace_words( + *, + text: str | AsyncIterable[str], + replacements: dict[str, str], +) -> str | AsyncIterable[str]: + """ + Replace words in text with another str + Args: + text: text to replace words in + words: dictionary of words to replace + """ + if isinstance(text, str): + for word, replacement in replacements.items(): + text = text.replace(word, replacement) + return text + else: + + async def _replace_words(): + buffer = "" + async for chunk in text: + for char in chunk: + if not char.isspace(): + buffer += char + else: + if buffer: + yield replacements.get(buffer, buffer) + buffer = "" + yield char + + if buffer: + yield replacements.get(buffer, buffer) + + return _replace_words() diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 931713eeb..692878939 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,5 +1,5 @@ import pytest -from livekit.agents import tokenize +from livekit.agents import tokenize, utils from livekit.agents.tokenize import basic from livekit.plugins import nltk @@ -141,3 +141,32 @@ def test_hyphenate_word(): for i, word in enumerate(HYPHENATOR_TEXT): hyphenated = basic.hyphenate_word(word) assert hyphenated == HYPHENATOR_EXPECTED[i] + + +REPLACE_TEXT = "This is a test. Hello world, I'm creating this agents.. framework" +REPLACE_EXPECTED = ( + "This is a test. Hello universe, I'm creating this agents.. library" +) + +REPLACE_REPLACEMENTS = { + "world": "universe", + "framework": "library", +} + + +def test_replace_words(): + replaced = utils.replace_words(text=REPLACE_TEXT, replacements=REPLACE_REPLACEMENTS) + assert replaced == REPLACE_EXPECTED + + +async def text_replace_words_async(): + pattern = [1, 2, 4] + text = REPLACE_TEXT + chunks = [] + pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1)) + + for chunk_size in pattern_iter: + if not text: + break + chunks.append(text[:chunk_size]) + text = text[chunk_size:] From 2a8896fa454861d3999735964e75195cda127d43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Fri, 6 Sep 2024 15:56:24 -0700 Subject: [PATCH 05/13] Update voice_assistant.py --- .../agents/voice_assistant/voice_assistant.py | 48 ++++++++++++------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index 760f1b170..197f42e8d 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings import asyncio import contextvars import time @@ -17,13 +18,15 @@ from .plotter import AssistantPlotter from .speech_handle import SpeechHandle -WillSynthesizeAssistantReply = Callable[ + +PreLLMGenerationCallback = Callable[ ["VoiceAssistant", ChatContext], Union[Optional[LLMStream], Awaitable[Optional[LLMStream]]], ] +WillSynthesizeAssistantReply = PreLLMGenerationCallback -WillSynthesizeAssistantSpeech = Callable[ +PreTTSGenerationCallback = Callable[ ["VoiceAssistant", Union[str, AsyncIterable[str]]], Union[str, AsyncIterable[str], Awaitable[str]], ] @@ -73,13 +76,13 @@ def llm_stream(self) -> LLMStream: return self._llm_stream -def _default_will_synthesize_assistant_reply( +def _default_llm_generation_cb( assistant: VoiceAssistant, chat_ctx: ChatContext ) -> LLMStream: return assistant.llm.chat(chat_ctx=chat_ctx, fnc_ctx=assistant.fnc_ctx) -def _default_will_synthesize_assistant_speech( +def _default_tts_generation_cb( assistant: VoiceAssistant, text: str | AsyncIterable[str] ) -> str | AsyncIterable[str]: return text @@ -91,8 +94,8 @@ class _ImplOptions: int_speech_duration: float int_min_words: int preemptive_synthesis: bool - will_synthesize_assistant_reply: WillSynthesizeAssistantReply - will_synthesize_assistant_speech: WillSynthesizeAssistantSpeech + pre_llm_generation_cb: PreLLMGenerationCallback + pre_tts_generation_cb: PreTTSGenerationCallback plotting: bool transcription: AssistantTranscriptionOptions @@ -137,10 +140,12 @@ def __init__( interrupt_min_words: int = 0, preemptive_synthesis: bool = True, transcription: AssistantTranscriptionOptions = AssistantTranscriptionOptions(), - will_synthesize_assistant_reply: WillSynthesizeAssistantReply = _default_will_synthesize_assistant_reply, - will_synthesize_assistant_speech: WillSynthesizeAssistantSpeech = _default_will_synthesize_assistant_speech, + pre_llm_generation_cb: PreLLMGenerationCallback = _default_llm_generation_cb, + pre_tts_generation_cb: PreTTSGenerationCallback = _default_tts_generation_cb, plotting: bool = False, loop: asyncio.AbstractEventLoop | None = None, + # backward compatibility + will_synthesize_assistant_reply: PreLLMGenerationCallback | None = None, ) -> None: """ Create a new VoiceAssistant. @@ -158,9 +163,9 @@ def __init__( Defaults to 0 as this may increase the latency depending on the STT. preemptive_synthesis: Whether to preemptively synthesize responses. transcription: Options for assistant transcription. - will_synthesize_assistant_reply: Callback called when the assistant is about to synthesize a reply. + pre_llm_generation_cb: Callback called when the assistant is about to synthesize a reply. This can be used to customize the reply (e.g: inject context/RAG). - will_synthesize_assistant_speech: Callback called when the assistant is about to + pre_tts_generation_cb: Callback called when the assistant is about to synthesize a speech. This can be used to customize text before the speech synthesis. (e.g: editing the pronunciation of a word). plotting: Whether to enable plotting for debugging. matplotlib must be installed. @@ -168,6 +173,15 @@ def __init__( """ super().__init__() self._loop = loop or asyncio.get_event_loop() + + if will_synthesize_assistant_reply is not None: + warnings.warn( + "will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use on_generate_llm instead", + DeprecationWarning, + stacklevel=2, + ) + on_generate_llm = will_synthesize_assistant_reply + self._opts = _ImplOptions( plotting=plotting, allow_interruptions=allow_interruptions, @@ -175,8 +189,8 @@ def __init__( int_min_words=interrupt_min_words, preemptive_synthesis=preemptive_synthesis, transcription=transcription, - will_synthesize_assistant_reply=will_synthesize_assistant_reply, - will_synthesize_assistant_speech=will_synthesize_assistant_speech, + pre_llm_generation_cb=pre_llm_generation_cb, + pre_tts_generation_cb=pre_tts_generation_cb, ) self._plotter = AssistantPlotter(self._loop) @@ -524,15 +538,13 @@ async def _synthesize_answer_task( ChatMessage.create(text=handle.user_question, role="user") ) - llm_stream = self._opts.will_synthesize_assistant_reply(self, copied_ctx) + llm_stream = self._opts.pre_llm_generation_cb(self, copied_ctx) if asyncio.iscoroutine(llm_stream): llm_stream = await llm_stream # fallback to default impl if no custom/user stream is returned if not isinstance(llm_stream, LLMStream): - llm_stream = _default_will_synthesize_assistant_reply( - self, chat_ctx=copied_ctx - ) + llm_stream = _default_llm_generation_cb(self, chat_ctx=copied_ctx) synthesis_handle = self._synthesize_agent_speech(handle.id, llm_stream) handle.initialize(source=llm_stream, synthesis_handle=synthesis_handle) @@ -731,7 +743,9 @@ def _synthesize_agent_speech( if isinstance(source, LLMStream): source = _llm_stream_to_str_iterable(speech_id, source) - speech_source = self._opts.will_synthesize_assistant_speech(self, source) + speech_source = self._opts.pre_tts_generation_cb(self, source) + if speech_source is None: + logger.error("pre_tts_generation_cb must return str or AsyncIterable[str]") return self._agent_output.synthesize( speech_id=speech_id, From 4d22bd2e7bf60a44391110e020e2583d658a2fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Fri, 6 Sep 2024 16:28:44 -0700 Subject: [PATCH 06/13] wip --- .../voice-assistant/custom_pronunciation.py | 5 +-- .../agents/transcription/stt_forwarder.py | 30 +++++++++----- .../agents/transcription/tts_forwarder.py | 30 ++++++++++---- .../livekit/agents/utils/__init__.py | 2 +- .../agents/voice_assistant/voice_assistant.py | 41 +++++++++---------- 5 files changed, 64 insertions(+), 44 deletions(-) diff --git a/examples/voice-assistant/custom_pronunciation.py b/examples/voice-assistant/custom_pronunciation.py index 15158a4dc..45236ba22 100644 --- a/examples/voice-assistant/custom_pronunciation.py +++ b/examples/voice-assistant/custom_pronunciation.py @@ -1,12 +1,11 @@ from __future__ import annotations + from typing import AsyncIterable -import asyncio from dotenv import load_dotenv -from livekit import rtc from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, utils from livekit.agents.voice_assistant import VoiceAssistant -from livekit.plugins import deepgram, cartesia, openai, silero +from livekit.plugins import cartesia, deepgram, openai, silero load_dotenv() diff --git a/livekit-agents/livekit/agents/transcription/stt_forwarder.py b/livekit-agents/livekit/agents/transcription/stt_forwarder.py index 410b343eb..3ec7d79ec 100644 --- a/livekit-agents/livekit/agents/transcription/stt_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/stt_forwarder.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import warnings from typing import Awaitable, Callable, Optional, Union from livekit import rtc @@ -10,13 +11,16 @@ from ..log import logger from . import _utils -WillForwardTranscription = Callable[ +BeforeForwardCallback = Callable[ ["STTSegmentsForwarder", rtc.Transcription], Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]], ] -def _default_will_forward_transcription( +WillForwardTranscription = BeforeForwardCallback + + +def _default_before_forward_cb( fwd: STTSegmentsForwarder, transcription: rtc.Transcription ) -> rtc.Transcription: return transcription @@ -33,7 +37,9 @@ def __init__( room: rtc.Room, participant: rtc.Participant | str, track: rtc.Track | rtc.TrackPublication | str | None = None, - will_forward_transcription: WillForwardTranscription = _default_will_forward_transcription, + before_forward_cb: BeforeForwardCallback = _default_before_forward_cb, + # backward compatibility + will_forward_transcription: WillForwardTranscription | None = None, ): identity = participant if isinstance(participant, str) else participant.identity if track is None: @@ -41,8 +47,16 @@ def __init__( elif isinstance(track, (rtc.TrackPublication, rtc.Track)): track = track.sid + if will_forward_transcription is not None: + warnings.warn( + "will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead", + DeprecationWarning, + stacklevel=2, + ) + before_forward_cb = will_forward_transcription + self._room, self._participant_identity, self._track_id = room, identity, track - self._will_forward_transcription = will_forward_transcription + self._before_forward_cb = before_forward_cb self._queue = asyncio.Queue[Optional[rtc.TranscriptionSegment]]() self._main_task = asyncio.create_task(self._run()) self._current_id = _utils.segment_uuid() @@ -60,16 +74,12 @@ async def _run(self): segments=[seg], # no history for now ) - transcription = self._will_forward_transcription( - self, base_transcription - ) + transcription = self._before_forward_cb(self, base_transcription) if asyncio.iscoroutine(transcription): transcription = await transcription if not isinstance(transcription, rtc.Transcription): - transcription = _default_will_forward_transcription( - self, base_transcription - ) + transcription = _default_before_forward_cb(self, base_transcription) if transcription.segments and self._room.isconnected(): await self._room.local_participant.publish_transcription( diff --git a/livekit-agents/livekit/agents/transcription/tts_forwarder.py b/livekit-agents/livekit/agents/transcription/tts_forwarder.py index 45125c7a7..87c88723a 100644 --- a/livekit-agents/livekit/agents/transcription/tts_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/tts_forwarder.py @@ -3,6 +3,7 @@ import asyncio import contextlib import time +import warnings from collections import deque from dataclasses import dataclass from typing import Awaitable, Callable, Deque, Optional, Union @@ -18,13 +19,16 @@ STANDARD_SPEECH_RATE = 3.83 -WillForwardTranscription = Callable[ +BeforeForwardCallback = Callable[ ["TTSSegmentsForwarder", rtc.Transcription], Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]], ] -def _default_will_forward_transcription( +WillForwardTranscription = BeforeForwardCallback + + +def _default_before_forward_callback( fwd: TTSSegmentsForwarder, transcription: rtc.Transcription ) -> rtc.Transcription: return transcription @@ -41,7 +45,7 @@ class _TTSOptions: sentence_tokenizer: tokenize.SentenceTokenizer hyphenate_word: Callable[[str], list[str]] new_sentence_delay: float - will_forward_transcription: WillForwardTranscription + before_forward_cb: BeforeForwardCallback @dataclass @@ -84,8 +88,10 @@ def __init__( word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(), sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer(), hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word, - will_forward_transcription: WillForwardTranscription = _default_will_forward_transcription, + before_forward_cb: BeforeForwardCallback = _default_before_forward_callback, loop: asyncio.AbstractEventLoop | None = None, + # backward compatibility + will_forward_transcription: WillForwardTranscription | None = None, ): """ Args: @@ -110,6 +116,14 @@ def __init__( elif isinstance(track, (rtc.TrackPublication, rtc.Track)): track = track.sid + if will_forward_transcription is not None: + warnings.warn( + "will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead", + DeprecationWarning, + stacklevel=2, + ) + before_forward_cb = will_forward_transcription + speed = speed * STANDARD_SPEECH_RATE self._opts = _TTSOptions( room=room, @@ -121,7 +135,7 @@ def __init__( sentence_tokenizer=sentence_tokenizer, hyphenate_word=hyphenate_word, new_sentence_delay=new_sentence_delay, - will_forward_transcription=will_forward_transcription, + before_forward_cb=before_forward_cb, ) self._closed = False self._loop = loop or asyncio.get_event_loop() @@ -247,15 +261,13 @@ async def _forward_task(): segments=[seg], # no history for now ) - transcription = self._opts.will_forward_transcription( - self, base_transcription - ) + transcription = self._opts.before_forward_cb(self, base_transcription) if asyncio.iscoroutine(transcription): transcription = await transcription # fallback to default impl if no custom/user stream is returned if not isinstance(transcription, rtc.Transcription): - transcription = _default_will_forward_transcription( + transcription = _default_before_forward_callback( self, base_transcription ) diff --git a/livekit-agents/livekit/agents/utils/__init__.py b/livekit-agents/livekit/agents/utils/__init__.py index c139cf0c9..5beffa3ea 100644 --- a/livekit-agents/livekit/agents/utils/__init__.py +++ b/livekit-agents/livekit/agents/utils/__init__.py @@ -2,7 +2,7 @@ from .event_emitter import EventEmitter from .exp_filter import ExpFilter from .log import log_exceptions -from .misc import AudioBuffer, merge_frames, shortuuid, time_ms, replace_words +from .misc import AudioBuffer, merge_frames, replace_words, shortuuid, time_ms from .moving_average import MovingAverage __all__ = [ diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index 197f42e8d..1637bb87e 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -1,9 +1,9 @@ from __future__ import annotations -import warnings import asyncio import contextvars import time +import warnings from dataclasses import dataclass from typing import Any, AsyncIterable, Awaitable, Callable, Literal, Optional, Union @@ -18,15 +18,14 @@ from .plotter import AssistantPlotter from .speech_handle import SpeechHandle - -PreLLMGenerationCallback = Callable[ +BeforeLLMCallback = Callable[ ["VoiceAssistant", ChatContext], Union[Optional[LLMStream], Awaitable[Optional[LLMStream]]], ] -WillSynthesizeAssistantReply = PreLLMGenerationCallback +WillSynthesizeAssistantReply = BeforeLLMCallback -PreTTSGenerationCallback = Callable[ +BeforeTTSCallback = Callable[ ["VoiceAssistant", Union[str, AsyncIterable[str]]], Union[str, AsyncIterable[str], Awaitable[str]], ] @@ -76,13 +75,13 @@ def llm_stream(self) -> LLMStream: return self._llm_stream -def _default_llm_generation_cb( +def _default_before_llm_cb( assistant: VoiceAssistant, chat_ctx: ChatContext ) -> LLMStream: return assistant.llm.chat(chat_ctx=chat_ctx, fnc_ctx=assistant.fnc_ctx) -def _default_tts_generation_cb( +def _default_before_tts_cb( assistant: VoiceAssistant, text: str | AsyncIterable[str] ) -> str | AsyncIterable[str]: return text @@ -94,8 +93,8 @@ class _ImplOptions: int_speech_duration: float int_min_words: int preemptive_synthesis: bool - pre_llm_generation_cb: PreLLMGenerationCallback - pre_tts_generation_cb: PreTTSGenerationCallback + before_llm_cb: BeforeLLMCallback + before_tts_cb: BeforeTTSCallback plotting: bool transcription: AssistantTranscriptionOptions @@ -140,12 +139,12 @@ def __init__( interrupt_min_words: int = 0, preemptive_synthesis: bool = True, transcription: AssistantTranscriptionOptions = AssistantTranscriptionOptions(), - pre_llm_generation_cb: PreLLMGenerationCallback = _default_llm_generation_cb, - pre_tts_generation_cb: PreTTSGenerationCallback = _default_tts_generation_cb, + before_llm_cb: BeforeLLMCallback = _default_before_llm_cb, + before_tts_cb: BeforeTTSCallback = _default_before_tts_cb, plotting: bool = False, loop: asyncio.AbstractEventLoop | None = None, # backward compatibility - will_synthesize_assistant_reply: PreLLMGenerationCallback | None = None, + will_synthesize_assistant_reply: WillSynthesizeAssistantReply | None = None, ) -> None: """ Create a new VoiceAssistant. @@ -163,9 +162,9 @@ def __init__( Defaults to 0 as this may increase the latency depending on the STT. preemptive_synthesis: Whether to preemptively synthesize responses. transcription: Options for assistant transcription. - pre_llm_generation_cb: Callback called when the assistant is about to synthesize a reply. + before_llm_cb: Callback called when the assistant is about to synthesize a reply. This can be used to customize the reply (e.g: inject context/RAG). - pre_tts_generation_cb: Callback called when the assistant is about to + before_tts_cb: Callback called when the assistant is about to synthesize a speech. This can be used to customize text before the speech synthesis. (e.g: editing the pronunciation of a word). plotting: Whether to enable plotting for debugging. matplotlib must be installed. @@ -176,11 +175,11 @@ def __init__( if will_synthesize_assistant_reply is not None: warnings.warn( - "will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use on_generate_llm instead", + "will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use before_llm_cb instead", DeprecationWarning, stacklevel=2, ) - on_generate_llm = will_synthesize_assistant_reply + before_llm_cb = will_synthesize_assistant_reply self._opts = _ImplOptions( plotting=plotting, @@ -189,8 +188,8 @@ def __init__( int_min_words=interrupt_min_words, preemptive_synthesis=preemptive_synthesis, transcription=transcription, - pre_llm_generation_cb=pre_llm_generation_cb, - pre_tts_generation_cb=pre_tts_generation_cb, + before_llm_cb=before_llm_cb, + before_tts_cb=before_tts_cb, ) self._plotter = AssistantPlotter(self._loop) @@ -538,13 +537,13 @@ async def _synthesize_answer_task( ChatMessage.create(text=handle.user_question, role="user") ) - llm_stream = self._opts.pre_llm_generation_cb(self, copied_ctx) + llm_stream = self._opts.before_llm_cb(self, copied_ctx) if asyncio.iscoroutine(llm_stream): llm_stream = await llm_stream # fallback to default impl if no custom/user stream is returned if not isinstance(llm_stream, LLMStream): - llm_stream = _default_llm_generation_cb(self, chat_ctx=copied_ctx) + llm_stream = _default_before_llm_cb(self, chat_ctx=copied_ctx) synthesis_handle = self._synthesize_agent_speech(handle.id, llm_stream) handle.initialize(source=llm_stream, synthesis_handle=synthesis_handle) @@ -743,7 +742,7 @@ def _synthesize_agent_speech( if isinstance(source, LLMStream): source = _llm_stream_to_str_iterable(speech_id, source) - speech_source = self._opts.pre_tts_generation_cb(self, source) + speech_source = self._opts.before_tts_cb(self, source) if speech_source is None: logger.error("pre_tts_generation_cb must return str or AsyncIterable[str]") From eba9b2ce15b0e048c42660a4a78c5d05f098370d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Fri, 6 Sep 2024 18:48:01 -0700 Subject: [PATCH 07/13] wip --- .../livekit/agents/tokenize/__init__.py | 3 +- .../livekit/agents/tokenize/utils.py | 89 +++++++++++++++++++ .../livekit/agents/utils/__init__.py | 3 +- livekit-agents/livekit/agents/utils/misc.py | 53 +---------- tests/test_tokenizer.py | 32 +++++-- 5 files changed, 120 insertions(+), 60 deletions(-) create mode 100644 livekit-agents/livekit/agents/tokenize/utils.py diff --git a/livekit-agents/livekit/agents/tokenize/__init__.py b/livekit-agents/livekit/agents/tokenize/__init__.py index 1a9eafb57..5b18d0e29 100644 --- a/livekit-agents/livekit/agents/tokenize/__init__.py +++ b/livekit-agents/livekit/agents/tokenize/__init__.py @@ -1,4 +1,4 @@ -from . import basic +from . import basic, utils from .token_stream import ( BufferedSentenceStream, BufferedWordStream, @@ -20,4 +20,5 @@ "BufferedSentenceStream", "BufferedWordStream", "basic", + "utils", ] diff --git a/livekit-agents/livekit/agents/tokenize/utils.py b/livekit-agents/livekit/agents/tokenize/utils.py new file mode 100644 index 000000000..87dfaec37 --- /dev/null +++ b/livekit-agents/livekit/agents/tokenize/utils.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import AsyncIterable, overload + +from . import _basic_word, tokenizer + + +@overload +def replace_words( + *, + text: str, + replacements: dict[str, str], +) -> str: ... + + +@overload +def replace_words( + *, + text: AsyncIterable[str], + replacements: dict[str, str], +) -> AsyncIterable[str]: ... + + +def replace_words( + *, + text: str | AsyncIterable[str], + replacements: dict[str, str], +) -> str | AsyncIterable[str]: + """ + Replace words in the given (async) text. The replacements are case-insensitive and the + replacement will keep the case of the original word. + Args: + text: text to replace words in + words: dictionary of words to replace + """ + + replacements = {k.lower(): v for k, v in replacements.items()} + + def _match_case(word, replacement): + if word.isupper(): + return replacement.upper() + elif word.istitle(): + return replacement.title() + else: + return replacement.lower() + + def _process_words(text, words): + offset = 0 + processed_index = 0 + for word, start_index, end_index in words: + no_punctuation = word.rstrip("".join(tokenizer.PUNCTUATIONS)) + punctuation_off = len(word) - len(no_punctuation) + replacement = replacements.get(no_punctuation.lower()) + if replacement: + text = ( + text[: start_index + offset] + + _match_case(word, replacement) + + text[end_index + offset - punctuation_off :] + ) + offset += len(replacement) - len(word) + punctuation_off + processed_index = end_index + offset + + return text, processed_index + + if isinstance(text, str): + words = _basic_word.split_words(text, ignore_punctuation=False) + text, _ = _process_words(text, words) + return text + else: + + async def _replace_words(): + buffer = "" + async for chunk in text: + buffer += chunk + words = _basic_word.split_words(buffer, ignore_punctuation=False) + + if len(words) <= 1: + continue + + buffer, procesed_index = _process_words(buffer, words) + yield buffer[:procesed_index] + buffer = buffer[procesed_index:] + + if buffer: + words = _basic_word.split_words(buffer, ignore_punctuation=False) + buffer, _ = _process_words(buffer, words) + yield buffer + + return _replace_words() diff --git a/livekit-agents/livekit/agents/utils/__init__.py b/livekit-agents/livekit/agents/utils/__init__.py index 5beffa3ea..361308572 100644 --- a/livekit-agents/livekit/agents/utils/__init__.py +++ b/livekit-agents/livekit/agents/utils/__init__.py @@ -2,7 +2,7 @@ from .event_emitter import EventEmitter from .exp_filter import ExpFilter from .log import log_exceptions -from .misc import AudioBuffer, merge_frames, replace_words, shortuuid, time_ms +from .misc import AudioBuffer, merge_frames, shortuuid, time_ms from .moving_average import MovingAverage __all__ = [ @@ -10,7 +10,6 @@ "merge_frames", "time_ms", "shortuuid", - "replace_words", "http_context", "ExpFilter", "MovingAverage", diff --git a/livekit-agents/livekit/agents/utils/misc.py b/livekit-agents/livekit/agents/utils/misc.py index a926e39f1..f85ae15b7 100644 --- a/livekit-agents/livekit/agents/utils/misc.py +++ b/livekit-agents/livekit/agents/utils/misc.py @@ -2,7 +2,7 @@ import time import uuid -from typing import AsyncIterable, List, Union, overload +from typing import List, Union from livekit import rtc @@ -50,54 +50,3 @@ def time_ms() -> int: def shortuuid() -> str: return str(uuid.uuid4().hex)[:12] - - -@overload -def replace_words( - *, - text: str, - replacements: dict[str, str], -) -> str: ... - - -@overload -def replace_words( - *, - text: AsyncIterable[str], - replacements: dict[str, str], -) -> AsyncIterable[str]: ... - - -def replace_words( - *, - text: str | AsyncIterable[str], - replacements: dict[str, str], -) -> str | AsyncIterable[str]: - """ - Replace words in text with another str - Args: - text: text to replace words in - words: dictionary of words to replace - """ - if isinstance(text, str): - for word, replacement in replacements.items(): - text = text.replace(word, replacement) - return text - else: - - async def _replace_words(): - buffer = "" - async for chunk in text: - for char in chunk: - if not char.isspace(): - buffer += char - else: - if buffer: - yield replacements.get(buffer, buffer) - buffer = "" - yield char - - if buffer: - yield replacements.get(buffer, buffer) - - return _replace_words() diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 692878939..7b1d79141 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,5 +1,5 @@ import pytest -from livekit.agents import tokenize, utils +from livekit.agents import tokenize from livekit.agents.tokenize import basic from livekit.plugins import nltk @@ -143,23 +143,31 @@ def test_hyphenate_word(): assert hyphenated == HYPHENATOR_EXPECTED[i] -REPLACE_TEXT = "This is a test. Hello world, I'm creating this agents.. framework" +REPLACE_TEXT = ( + "This is a test. Hello world, I'm creating this agents.. framework. Once again " + "framework. A.B.C" +) REPLACE_EXPECTED = ( - "This is a test. Hello universe, I'm creating this agents.. library" + "This is a test. Hello universe, I'm creating this agents.. library. Twice again " + "library. A.B.C.D" ) REPLACE_REPLACEMENTS = { "world": "universe", "framework": "library", + "a.b.c": "A.B.C.D", + "once": "twice", } def test_replace_words(): - replaced = utils.replace_words(text=REPLACE_TEXT, replacements=REPLACE_REPLACEMENTS) + replaced = tokenize.utils.replace_words( + text=REPLACE_TEXT, replacements=REPLACE_REPLACEMENTS + ) assert replaced == REPLACE_EXPECTED -async def text_replace_words_async(): +async def test_replace_words_async(): pattern = [1, 2, 4] text = REPLACE_TEXT chunks = [] @@ -170,3 +178,17 @@ async def text_replace_words_async(): break chunks.append(text[:chunk_size]) text = text[chunk_size:] + + async def _replace_words_async(): + for chunk in chunks: + yield chunk + + replaced_chunks = [] + + async for chunk in tokenize.utils.replace_words( + text=_replace_words_async(), replacements=REPLACE_REPLACEMENTS + ): + replaced_chunks.append(chunk) + + replaced = "".join(replaced_chunks) + assert replaced == REPLACE_EXPECTED From 8dbbf21a61362e551b8b0ffffa46e85ce481cfff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Fri, 6 Sep 2024 19:17:46 -0700 Subject: [PATCH 08/13] fix example --- .../voice-assistant/custom_pronunciation.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/voice-assistant/custom_pronunciation.py b/examples/voice-assistant/custom_pronunciation.py index 45236ba22..e6ff7cd52 100644 --- a/examples/voice-assistant/custom_pronunciation.py +++ b/examples/voice-assistant/custom_pronunciation.py @@ -3,7 +3,7 @@ from typing import AsyncIterable from dotenv import load_dotenv -from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, utils +from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, tokenize from livekit.agents.voice_assistant import VoiceAssistant from livekit.plugins import cartesia, deepgram, openai, silero @@ -21,18 +21,16 @@ async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) - async def _will_synthesize_assistant_speech( - assistant: VoiceAssistant, text: str | AsyncIterable[str] - ): - # Cartesia TTS is incorrectly pronouncing "LiveKit", so we'll replace it with a phonetic + def _before_tts_cb(assistant: VoiceAssistant, text: str | AsyncIterable[str]): + # The TTS is incorrectly pronouncing "LiveKit", so we'll replace it with a phonetic # spelling - return utils.replace_words( - text=text, replacements={"LiveKit": "<>Kit"} + return tokenize.utils.replace_words( + text=text, replacements={"livekit": r"<>"} ) # also for this example, we also intensify the keyword "LiveKit" to make it more likely to be # recognized with the STT - deepgram_stt = deepgram.STT(keywords=[("LiveKit", 2.0)]) + deepgram_stt = deepgram.STT(keywords=[("LiveKit", 3.5)]) assistant = VoiceAssistant( vad=silero.VAD.load(), @@ -40,11 +38,11 @@ async def _will_synthesize_assistant_speech( llm=openai.LLM(), tts=cartesia.TTS(), chat_ctx=initial_ctx, - will_synthesize_assistant_speech=_will_synthesize_assistant_speech, + before_tts_cb=_before_tts_cb, ) assistant.start(ctx.room) - await assistant.say("Hey, how can I help you today?", allow_interruptions=True) + await assistant.say("Hey, LiveKit is awesome!", allow_interruptions=True) if __name__ == "__main__": From 691c731eb93d0ca6a90d900647b14f6946b368a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sat, 7 Sep 2024 22:57:26 -0700 Subject: [PATCH 09/13] tts_forwarder: remove ordering constraint --- .../livekit/agents/tokenize/utils.py | 5 +- .../agents/transcription/tts_forwarder.py | 232 ++++++++++-------- .../livekit/agents/utils/aio/__init__.py | 3 +- .../livekit/agents/utils/aio/itertools.py | 114 +++++++++ .../agents/voice_assistant/agent_output.py | 87 ++++--- .../agents/voice_assistant/voice_assistant.py | 14 +- 6 files changed, 310 insertions(+), 145 deletions(-) create mode 100644 livekit-agents/livekit/agents/utils/aio/itertools.py diff --git a/livekit-agents/livekit/agents/tokenize/utils.py b/livekit-agents/livekit/agents/tokenize/utils.py index 87dfaec37..912941943 100644 --- a/livekit-agents/livekit/agents/tokenize/utils.py +++ b/livekit-agents/livekit/agents/tokenize/utils.py @@ -58,7 +58,8 @@ def _process_words(text, words): + text[end_index + offset - punctuation_off :] ) offset += len(replacement) - len(word) + punctuation_off - processed_index = end_index + offset + + processed_index = end_index + offset return text, processed_index @@ -77,7 +78,7 @@ async def _replace_words(): if len(words) <= 1: continue - buffer, procesed_index = _process_words(buffer, words) + buffer, procesed_index = _process_words(buffer, words[:-1]) yield buffer[:procesed_index] buffer = buffer[procesed_index:] diff --git a/livekit-agents/livekit/agents/transcription/tts_forwarder.py b/livekit-agents/livekit/agents/transcription/tts_forwarder.py index 87c88723a..88f8d6d94 100644 --- a/livekit-agents/livekit/agents/transcription/tts_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/tts_forwarder.py @@ -4,11 +4,11 @@ import contextlib import time import warnings -from collections import deque from dataclasses import dataclass -from typing import Awaitable, Callable, Deque, Optional, Union +from typing import Awaitable, Callable, Optional, Union from livekit import rtc +from livekit.rtc.participant import PublishTranscriptionError from .. import tokenize, utils from ..log import logger @@ -49,23 +49,19 @@ class _TTSOptions: @dataclass -class _SegmentData: - segment_index: int - sentence_stream: tokenize.SentenceStream - pushed_text: str = "" +class _AudioData: pushed_duration: float = 0.0 - real_speed: float | None = None - processed_sentences: int = 0 - processed_hyphens: int = 0 - validated: bool = False - forward_start_time: float | None = 0.0 + done: bool = False @dataclass -class _FormingSegments: - audio: _SegmentData - text: _SegmentData - q: deque[_SegmentData] +class _TextData: + sentence_stream: tokenize.SentenceStream + pushed_text: str = "" + done: bool = False + + forwarded_hyphens: int = 0 + forwarded_sentences: int = 0 class TTSSegmentsForwarder: @@ -141,25 +137,22 @@ def __init__( self._loop = loop or asyncio.get_event_loop() self._close_future = asyncio.Future[None]() - self._next_segment_index = 0 self._playing_seg_index = -1 self._finshed_seg_index = -1 - first_segment = self._create_segment() - segments_q: Deque[_SegmentData] = deque() - segments_q.append(first_segment) + self._text_q_changed = asyncio.Event() + self._text_q = list[Union[_TextData, None]]() + self._audio_q_changed = asyncio.Event() + self._audio_q = list[Union[_AudioData, None]]() - self._forming_segments = _FormingSegments( - audio=first_segment, text=first_segment, q=segments_q - ) + self._text_data: _TextData | None = None + self._audio_data: _AudioData | None = None + + self._played_text = "" - self._seg_queue = asyncio.Queue[Optional[_SegmentData]]() - self._seg_queue.put_nowait(first_segment) self._main_atask = self._loop.create_task(self._main_task()) self._task_set = utils.aio.TaskSet(loop) - self._played_text = "" - def segment_playout_started(self) -> None: """ Notify that the playout of the audio segment has started. @@ -179,47 +172,46 @@ def segment_playout_finished(self) -> None: def push_audio(self, frame: rtc.AudioFrame) -> None: self._check_not_closed() + + if self._audio_data is None: + self._audio_data = _AudioData() + self._audio_q.append(self._audio_data) + self._audio_q_changed.set() + frame_duration = frame.samples_per_channel / frame.sample_rate - cur_seg = self._forming_segments.audio - cur_seg.pushed_duration += frame_duration - cur_seg.validated = True + self._audio_data.pushed_duration += frame_duration def mark_audio_segment_end(self) -> None: self._check_not_closed() - try: - # get last ended segment (text always end before audio) - seg = self._forming_segments.q.popleft() - except IndexError: - raise IndexError( - "mark_audio_segment_end called before any mark_text_segment_end" - ) - if seg.pushed_duration > 0.0: - seg.real_speed = ( - len(self._calc_hyphens(seg.pushed_text)) / seg.pushed_duration - ) + if self._audio_data is None: + raise RuntimeError("mark_audio_segment_end called before any push_audio") - seg.validated = True - self._forming_segments.audio = self._forming_segments.q[0] + self._audio_data.done = True + self._audio_data = None def push_text(self, text: str) -> None: self._check_not_closed() - cur_seg = self._forming_segments.text - cur_seg.pushed_text += text - cur_seg.sentence_stream.push_text(text) + + if self._text_data is None: + self._text_data = _TextData( + sentence_stream=self._opts.sentence_tokenizer.stream() + ) + self._text_q.append(self._text_data) + self._text_q_changed.set() + + self._text_data.pushed_text += text + self._text_data.sentence_stream.push_text(text) def mark_text_segment_end(self) -> None: self._check_not_closed() - stream = self._forming_segments.text.sentence_stream - stream.end_input() - # create a new segment on "mark_text_segment_end" - # further text can already be pushed even if mark_audio_segment_end has not been - # called yet - new_seg = self._create_segment() - self._forming_segments.text = new_seg - self._forming_segments.q.append(new_seg) - self._seg_queue.put_nowait(new_seg) + if self._text_data is None: + raise RuntimeError("mark_text_segment_end called before any push_text") + + self._text_data.done = True + self._text_data.sentence_stream.end_input() + self._text_data = None @property def closed(self) -> bool: @@ -235,10 +227,15 @@ async def aclose(self) -> None: self._closed = True self._close_future.set_result(None) - self._seg_queue.put_nowait(None) - for seg in self._forming_segments.q: - await seg.sentence_stream.aclose() + for text_data in self._text_q: + assert text_data is not None + await text_data.sentence_stream.aclose() + + self._text_q.append(None) + self._audio_q.append(None) + self._text_q_changed.set() + self._audio_q_changed.set() await self._task_set.aclose() await self._main_atask @@ -246,19 +243,15 @@ async def aclose(self) -> None: @utils.log_exceptions(logger=logger) async def _main_task(self) -> None: """Main task that forwards the transcription to the room.""" - rtc_seg_q = asyncio.Queue[Optional[rtc.TranscriptionSegment]]() + rtc_seg_ch = utils.aio.Chan[rtc.TranscriptionSegment]() @utils.log_exceptions(logger=logger) async def _forward_task(): - while True: - seg = await rtc_seg_q.get() - if seg is None: - break - + async for rtc_seg in rtc_seg_ch: base_transcription = rtc.Transcription( participant_identity=self._opts.participant_identity, track_sid=self._opts.track_id, - segments=[seg], # no history for now + segments=[rtc_seg], # no history for now ) transcription = self._opts.before_forward_cb(self, base_transcription) @@ -272,50 +265,83 @@ async def _forward_task(): ) if transcription.segments and self._opts.room.isconnected(): - await self._opts.room.local_participant.publish_transcription( - transcription - ) + try: + await self._opts.room.local_participant.publish_transcription( + transcription + ) + except PublishTranscriptionError: + continue forward_task = asyncio.create_task(_forward_task()) - while True: - seg = await self._seg_queue.get() - if seg is None: - break + seg_index = 0 + q_done = False + while not q_done: + await self._text_q_changed.wait() + await self._audio_q_changed.wait() - # wait until the segment is validated and has started playing - while not self._closed: - if seg.validated and self._playing_seg_index >= seg.segment_index: + while self._text_q and self._audio_q: + text_data = self._text_q.pop(0) + audio_data = self._audio_q.pop(0) + + if text_data is None or audio_data is None: + q_done = True break - await self._sleep_if_not_closed(0.1) + # wait until the segment is validated and has started playing + while not self._closed: + if self._playing_seg_index >= seg_index: + break + + await self._sleep_if_not_closed(0.125) + + sentence_stream = text_data.sentence_stream + forward_start_time = time.time() + + async for ev in sentence_stream: + await self._sync_sentence_co( + seg_index, + forward_start_time, + text_data, + audio_data, + ev.token, + rtc_seg_ch, + ) - sentence_stream = seg.sentence_stream - seg.forward_start_time = time.time() + seg_index += 1 - async for ev in sentence_stream: - await self._sync_sentence_co(seg, ev.token, rtc_seg_q) + self._text_q_changed.clear() + self._audio_q_changed.clear() - rtc_seg_q.put_nowait(None) + rtc_seg_ch.close() await forward_task async def _sync_sentence_co( self, - seg: _SegmentData, - tokenized_sentence: str, - rtc_seg_q: asyncio.Queue[Optional[rtc.TranscriptionSegment]], + segment_index: int, + segment_start_time: float, + text_data: _TextData, + audio_data: _AudioData, + sentence: str, + rtc_seg_ch: utils.aio.Chan[rtc.TranscriptionSegment], ): """Synchronize the transcription with the audio playout for a given sentence.""" - assert seg.forward_start_time is not None - # put each sentence in a different transcription segment + + real_speed = None + if audio_data.pushed_duration > 0 and audio_data.done: + real_speed = ( + len(self._calc_hyphens(text_data.pushed_text)) + / audio_data.pushed_duration + ) + seg_id = _utils.segment_uuid() - words = self._opts.word_tokenizer.tokenize(text=tokenized_sentence) + words = self._opts.word_tokenizer.tokenize(text=sentence) processed_words: list[str] = [] og_text = self._played_text for word in words: - if seg.segment_index <= self._finshed_seg_index: + if segment_index <= self._finshed_seg_index: # playout of the audio segment already finished # break the loop and send the final transcription break @@ -328,22 +354,22 @@ async def _sync_sentence_co( processed_words.append(word) # elapsed time since the start of the seg - elapsed_time = time.time() - seg.forward_start_time + elapsed_time = time.time() - segment_start_time text = self._opts.word_tokenizer.format_words(processed_words) # remove any punctuation at the end of a non-final transcript text = text.rstrip("".join(PUNCTUATIONS)) speed = self._opts.speed - if seg.real_speed is not None: - speed = seg.real_speed + if real_speed is not None: + speed = real_speed estimated_pauses_s = ( - seg.processed_sentences * self._opts.new_sentence_delay + text_data.forwarded_sentences * self._opts.new_sentence_delay ) hyph_pauses = estimated_pauses_s * speed target_hyphens = round(speed * elapsed_time) - dt = target_hyphens - seg.processed_hyphens - hyph_pauses + dt = target_hyphens - text_data.forwarded_hyphens - hyph_pauses to_wait_hyphens = max(0.0, word_hyphens - dt) delay = to_wait_hyphens / speed else: @@ -351,7 +377,8 @@ async def _sync_sentence_co( first_delay = min(delay / 2, 2 / speed) await self._sleep_if_not_closed(first_delay) - rtc_seg_q.put_nowait( + + rtc_seg_ch.send_nowait( rtc.TranscriptionSegment( id=seg_id, text=text, @@ -362,23 +389,24 @@ async def _sync_sentence_co( ) ) self._played_text = f"{og_text} {text}" + await self._sleep_if_not_closed(delay - first_delay) - seg.processed_hyphens += word_hyphens + text_data.forwarded_hyphens += word_hyphens - rtc_seg_q.put_nowait( + rtc_seg_ch.send_nowait( rtc.TranscriptionSegment( id=seg_id, - text=tokenized_sentence, + text=sentence, start_time=0, end_time=0, final=True, language=self._opts.language, ) ) - self._played_text = f"{og_text} {tokenized_sentence}" + self._played_text = f"{og_text} {sentence}" await self._sleep_if_not_closed(self._opts.new_sentence_delay) - seg.processed_sentences += 1 + text_data.forwarded_sentences += 1 async def _sleep_if_not_closed(self, delay: float) -> None: with contextlib.suppress(asyncio.TimeoutError): @@ -393,14 +421,6 @@ def _calc_hyphens(self, text: str) -> list[str]: return hyphens - def _create_segment(self) -> _SegmentData: - data = _SegmentData( - segment_index=self._next_segment_index, - sentence_stream=self._opts.sentence_tokenizer.stream(), - ) - self._next_segment_index += 1 - return data - def _check_not_closed(self) -> None: if self._closed: raise RuntimeError("TTSForwarder is closed") diff --git a/livekit-agents/livekit/agents/utils/aio/__init__.py b/livekit-agents/livekit/agents/utils/aio/__init__.py index 803e12f73..2d0272011 100644 --- a/livekit-agents/livekit/agents/utils/aio/__init__.py +++ b/livekit-agents/livekit/agents/utils/aio/__init__.py @@ -1,7 +1,7 @@ import asyncio import contextlib -from . import debug, duplex_unix +from . import debug, duplex_unix, itertools from .channel import Chan, ChanClosed, ChanReceiver, ChanSender from .interval import Interval, interval from .sleep import Sleep, SleepFinished, sleep @@ -31,4 +31,5 @@ async def gracefully_cancel(*futures: asyncio.Future): "debug", "gracefully_cancel", "duplex_unix", + "itertools", ] diff --git a/livekit-agents/livekit/agents/utils/aio/itertools.py b/livekit-agents/livekit/agents/utils/aio/itertools.py new file mode 100644 index 000000000..0076f8eb5 --- /dev/null +++ b/livekit-agents/livekit/agents/utils/aio/itertools.py @@ -0,0 +1,114 @@ +import asyncio +from collections import deque +from typing import ( + Any, + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Deque, + Generic, + Iterator, + List, + Protocol, + Tuple, + TypeVar, + Union, + overload, + runtime_checkable, +) + +from typing_extensions import AsyncContextManager + +# based on https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py + + +@runtime_checkable +class _ACloseable(Protocol): + async def aclose(self) -> None: + """Asynchronously close this object""" + + +T = TypeVar("T") + + +async def tee_peer( + iterator: AsyncIterator[T], + buffer: Deque[T], + peers: List[Deque[T]], + lock: AsyncContextManager[Any], +) -> AsyncGenerator[T, None]: + try: + while True: + if not buffer: + async with lock: + if buffer: + continue + try: + item = await iterator.__anext__() + except StopAsyncIteration: + break + else: + for peer_buffer in peers: + peer_buffer.append(item) + yield buffer.popleft() + finally: + for idx, peer_buffer in enumerate(peers): # pragma: no branch + if peer_buffer is buffer: + peers.pop(idx) + break + + if not peers and isinstance(iterator, _ACloseable): + await iterator.aclose() + + +class Tee(Generic[T]): + __slots__ = ("_iterator", "_buffers", "_children") + + def __init__( + self, + iterator: AsyncIterable[T], + n: int = 2, + ): + self._iterator = iterator.__aiter__() + self._buffers: List[Deque[T]] = [deque() for _ in range(n)] + + lock = asyncio.Lock() + self._children = tuple( + tee_peer( + iterator=self._iterator, + buffer=buffer, + peers=self._buffers, + lock=lock, + ) + for buffer in self._buffers + ) + + def __len__(self) -> int: + return len(self._children) + + @overload + def __getitem__(self, item: int) -> AsyncIterator[T]: ... + + @overload + def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ... + + def __getitem__( + self, item: Union[int, slice] + ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]: + return self._children[item] + + def __iter__(self) -> Iterator[AsyncIterator[T]]: + yield from self._children + + async def __aenter__(self) -> "Tee[T]": + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.aclose() + + async def aclose(self) -> None: + for child in self._children: + await child.aclose() + + +tee = Tee diff --git a/livekit-agents/livekit/agents/voice_assistant/agent_output.py b/livekit-agents/livekit/agents/voice_assistant/agent_output.py index 0317a2a82..5542ddb38 100644 --- a/livekit-agents/livekit/agents/voice_assistant/agent_output.py +++ b/livekit-agents/livekit/agents/voice_assistant/agent_output.py @@ -20,13 +20,21 @@ def __init__( self, *, speech_id: str, - speech_source: SpeechSource, + tts_source: SpeechSource, + transcript_source: SpeechSource, agent_playout: AgentPlayout, tts: text_to_speech.TTS, transcription_fwd: agent_transcription.TTSSegmentsForwarder, ) -> None: - self._speech_source, self._agent_playout, self._tts, self._tr_fwd = ( - speech_source, + ( + self._tts_source, + self._transcript_source, + self._agent_playout, + self._tts, + self._tr_fwd, + ) = ( + tts_source, + transcript_source, agent_playout, tts, transcription_fwd, @@ -113,7 +121,8 @@ def synthesize( self, *, speech_id: str, - transcript: SpeechSource, + tts_source: SpeechSource, + transcript_source: SpeechSource, transcription: bool, transcription_speed: float, sentence_tokenizer: tokenize.SentenceTokenizer, @@ -140,7 +149,8 @@ def _will_forward_transcription( ) handle = SynthesisHandle( - speech_source=transcript, + tts_source=tts_source, + transcript_source=transcript_source, agent_playout=self._agent_playout, tts=self._tts, transcription_fwd=transcription_fwd, @@ -155,15 +165,16 @@ def _will_forward_transcription( @utils.log_exceptions(logger=logger) async def _synthesize_task(self, handle: SynthesisHandle) -> None: """Synthesize speech from the source""" - speech_source = handle._speech_source - - if isinstance(speech_source, Awaitable): - speech_source = await speech_source - co = _str_synthesis_task(speech_source, handle) - elif isinstance(speech_source, str): - co = _str_synthesis_task(speech_source, handle) + tts_source = handle._tts_source + transcript_source = handle._transcript_source + + if isinstance(tts_source, Awaitable): + tts_source = await tts_source + co = _str_synthesis_task(tts_source, transcript_source, handle) + elif isinstance(tts_source, str): + co = _str_synthesis_task(tts_source, transcript_source, handle) else: - co = _stream_synthesis_task(speech_source, handle) + co = _stream_synthesis_task(tts_source, transcript_source, handle) synth = asyncio.create_task(co) synth.add_done_callback(lambda _: handle._buf_ch.close()) @@ -176,17 +187,19 @@ async def _synthesize_task(self, handle: SynthesisHandle) -> None: @utils.log_exceptions(logger=logger) -async def _str_synthesis_task(text: str, handle: SynthesisHandle) -> None: +async def _str_synthesis_task( + tts_text: str, transcript: str, handle: SynthesisHandle +) -> None: """synthesize speech from a string""" if not handle.tts_forwarder.closed: - handle.tts_forwarder.push_text(text) + handle.tts_forwarder.push_text(transcript) handle.tts_forwarder.mark_text_segment_end() start_time = time.time() first_frame = True try: - async for audio in handle._tts.synthesize(text): + async for audio in handle._tts.synthesize(tts_text): if first_frame: first_frame = False logger.debug( @@ -211,7 +224,9 @@ async def _str_synthesis_task(text: str, handle: SynthesisHandle) -> None: @utils.log_exceptions(logger=logger) async def _stream_synthesis_task( - streamed_text: AsyncIterable[str], handle: SynthesisHandle + tts_source: AsyncIterable[str], + transcript_source: AsyncIterable[str], + handle: SynthesisHandle, ) -> None: """synthesize speech from streamed text""" @@ -237,33 +252,41 @@ async def _read_generated_audio_task(): handle._buf_ch.send_nowait(audio.frame) if handle._tr_fwd and not handle._tr_fwd.closed: - # mark_audio_segment_end must be called *after* mart_text_segment_end handle._tr_fwd.mark_audio_segment_end() + @utils.log_exceptions(logger=logger) + async def _read_transcript_task(): + async for seg in transcript_source: + if not handle._tr_fwd.closed: + handle._tr_fwd.push_text(seg) + + if not handle.tts_forwarder.closed: + handle.tts_forwarder.mark_text_segment_end() + # otherwise, stream the text to the TTS tts_stream = handle._tts.stream() - read_atask: asyncio.Task | None = None + read_tts_atask: asyncio.Task | None = None + read_transcript_atask: asyncio.Task | None = None try: - async for seg in streamed_text: - if not handle.tts_forwarder.closed: - handle.tts_forwarder.push_text(seg) - - if read_atask is None: + async for seg in tts_source: + if read_tts_atask is None: # start the task when we receive the first text segment (so start_time is more accurate) - read_atask = asyncio.create_task(_read_generated_audio_task()) + read_tts_atask = asyncio.create_task(_read_generated_audio_task()) + read_transcript_atask = asyncio.create_task(_read_transcript_task()) tts_stream.push_text(seg) - if not handle.tts_forwarder.closed: - handle.tts_forwarder.mark_text_segment_end() - tts_stream.end_input() - if read_atask is not None: - await read_atask + if read_tts_atask is not None: + assert read_transcript_atask is not None + await read_tts_atask + await read_transcript_atask + finally: - if read_atask is not None: - await utils.aio.gracefully_cancel(read_atask) + if read_tts_atask is not None: + assert read_transcript_atask is not None + await utils.aio.gracefully_cancel(read_tts_atask, read_transcript_atask) await tts_stream.aclose() diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index 52a61d827..ccf2a02d9 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -745,13 +745,19 @@ def _synthesize_agent_speech( if isinstance(source, LLMStream): source = _llm_stream_to_str_iterable(speech_id, source) - speech_source = self._opts.before_tts_cb(self, source) - if speech_source is None: - logger.error("pre_tts_generation_cb must return str or AsyncIterable[str]") + tts_source = source + transcript_source = source + if isinstance(tts_source, AsyncIterable): + tts_source, transcript_source = utils.aio.itertools.tee(tts_source, 2) + + tts_source = self._opts.before_tts_cb(self, tts_source) + if tts_source is None: + logger.error("before_tts_cb must return str or AsyncIterable[str]") return self._agent_output.synthesize( speech_id=speech_id, - transcript=speech_source, + tts_source=tts_source, + transcript_source=transcript_source, transcription=self._opts.transcription.agent_transcription, transcription_speed=self._opts.transcription.agent_transcription_speed, sentence_tokenizer=self._opts.transcription.sentence_tokenizer, From 4c75e2cbe4b5cbed2fb7a77e8de2e28c21febf87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sat, 7 Sep 2024 23:02:12 -0700 Subject: [PATCH 10/13] Update voice_assistant.py --- .../livekit/agents/voice_assistant/voice_assistant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index ccf2a02d9..e73d7984d 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -745,7 +745,7 @@ def _synthesize_agent_speech( if isinstance(source, LLMStream): source = _llm_stream_to_str_iterable(speech_id, source) - tts_source = source + tts_source: Union[AsyncIterable[str], str, Awaitable[str]] = source transcript_source = source if isinstance(tts_source, AsyncIterable): tts_source, transcript_source = utils.aio.itertools.tee(tts_source, 2) From 260a7b30f717aa071a36d494e55640b4504ab841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sat, 7 Sep 2024 23:05:25 -0700 Subject: [PATCH 11/13] Update voice_assistant.py --- .../livekit/agents/voice_assistant/voice_assistant.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index e73d7984d..fefc37634 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -745,12 +745,12 @@ def _synthesize_agent_speech( if isinstance(source, LLMStream): source = _llm_stream_to_str_iterable(speech_id, source) - tts_source: Union[AsyncIterable[str], str, Awaitable[str]] = source + og_source = source transcript_source = source - if isinstance(tts_source, AsyncIterable): - tts_source, transcript_source = utils.aio.itertools.tee(tts_source, 2) + if isinstance(og_source, AsyncIterable): + og_source, transcript_source = utils.aio.itertools.tee(og_source, 2) - tts_source = self._opts.before_tts_cb(self, tts_source) + tts_source = self._opts.before_tts_cb(self, og_source) if tts_source is None: logger.error("before_tts_cb must return str or AsyncIterable[str]") From 300abb9978967fca952044bb3bd5eb1d134cc5cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sat, 7 Sep 2024 23:14:12 -0700 Subject: [PATCH 12/13] fix --- livekit-agents/livekit/agents/transcription/stt_forwarder.py | 5 +---- livekit-agents/livekit/agents/transcription/tts_forwarder.py | 5 +---- .../livekit/agents/voice_assistant/agent_output.py | 4 ++-- livekit-agents/livekit/agents/voice_assistant/human_input.py | 4 ++-- .../livekit/agents/voice_assistant/voice_assistant.py | 5 +---- 5 files changed, 7 insertions(+), 16 deletions(-) diff --git a/livekit-agents/livekit/agents/transcription/stt_forwarder.py b/livekit-agents/livekit/agents/transcription/stt_forwarder.py index 3ec7d79ec..0d526a3a6 100644 --- a/livekit-agents/livekit/agents/transcription/stt_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/stt_forwarder.py @@ -2,7 +2,6 @@ import asyncio import contextlib -import warnings from typing import Awaitable, Callable, Optional, Union from livekit import rtc @@ -48,10 +47,8 @@ def __init__( track = track.sid if will_forward_transcription is not None: - warnings.warn( + logger.warning( "will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead", - DeprecationWarning, - stacklevel=2, ) before_forward_cb = will_forward_transcription diff --git a/livekit-agents/livekit/agents/transcription/tts_forwarder.py b/livekit-agents/livekit/agents/transcription/tts_forwarder.py index 88f8d6d94..357354089 100644 --- a/livekit-agents/livekit/agents/transcription/tts_forwarder.py +++ b/livekit-agents/livekit/agents/transcription/tts_forwarder.py @@ -3,7 +3,6 @@ import asyncio import contextlib import time -import warnings from dataclasses import dataclass from typing import Awaitable, Callable, Optional, Union @@ -113,10 +112,8 @@ def __init__( track = track.sid if will_forward_transcription is not None: - warnings.warn( + logger.warning( "will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead", - DeprecationWarning, - stacklevel=2, ) before_forward_cb = will_forward_transcription diff --git a/livekit-agents/livekit/agents/voice_assistant/agent_output.py b/livekit-agents/livekit/agents/voice_assistant/agent_output.py index 5542ddb38..f7747af6b 100644 --- a/livekit-agents/livekit/agents/voice_assistant/agent_output.py +++ b/livekit-agents/livekit/agents/voice_assistant/agent_output.py @@ -129,7 +129,7 @@ def synthesize( word_tokenizer: tokenize.WordTokenizer, hyphenate_word: Callable[[str], list[str]], ) -> SynthesisHandle: - def _will_forward_transcription( + def _before_forward( fwd: agent_transcription.TTSSegmentsForwarder, transcription: rtc.Transcription, ): @@ -145,7 +145,7 @@ def _will_forward_transcription( sentence_tokenizer=sentence_tokenizer, word_tokenizer=word_tokenizer, hyphenate_word=hyphenate_word, - will_forward_transcription=_will_forward_transcription, + before_forward_cb=_before_forward, ) handle = SynthesisHandle( diff --git a/livekit-agents/livekit/agents/voice_assistant/human_input.py b/livekit-agents/livekit/agents/voice_assistant/human_input.py index a3ddc5248..22fec121e 100644 --- a/livekit-agents/livekit/agents/voice_assistant/human_input.py +++ b/livekit-agents/livekit/agents/voice_assistant/human_input.py @@ -101,7 +101,7 @@ async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None: vad_stream = self._vad.stream() stt_stream = self._stt.stream() - def _will_forward_transcription( + def _before_forward( fwd: transcription.STTSegmentsForwarder, transcription: rtc.Transcription ): if not self._transcription: @@ -113,7 +113,7 @@ def _will_forward_transcription( room=self._room, participant=self._participant, track=self._subscribed_track, - will_forward_transcription=_will_forward_transcription, + before_forward_cb=_before_forward, ) async def _audio_stream_co() -> None: diff --git a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py index fefc37634..59a4431c8 100644 --- a/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py +++ b/livekit-agents/livekit/agents/voice_assistant/voice_assistant.py @@ -3,7 +3,6 @@ import asyncio import contextvars import time -import warnings from dataclasses import dataclass from typing import Any, AsyncIterable, Awaitable, Callable, Literal, Optional, Union @@ -174,10 +173,8 @@ def __init__( self._loop = loop or asyncio.get_event_loop() if will_synthesize_assistant_reply is not None: - warnings.warn( + logger.warning( "will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use before_llm_cb instead", - DeprecationWarning, - stacklevel=2, ) before_llm_cb = will_synthesize_assistant_reply From 41e64d2fce57cf42d482d06f26805bfbea7ad07f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Sat, 7 Sep 2024 23:16:07 -0700 Subject: [PATCH 13/13] Update test_tokenizer.py --- tests/test_tokenizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 7b1d79141..b4b554e60 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -148,7 +148,7 @@ def test_hyphenate_word(): "framework. A.B.C" ) REPLACE_EXPECTED = ( - "This is a test. Hello universe, I'm creating this agents.. library. Twice again " + "This is a test. Hello universe, I'm creating this assistants.. library. Twice again " "library. A.B.C.D" ) @@ -157,6 +157,7 @@ def test_hyphenate_word(): "framework": "library", "a.b.c": "A.B.C.D", "once": "twice", + "agents": "assistants", }