Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

voiceassistant: add before_tts_cb callback #706

Merged
merged 15 commits into from
Sep 8, 2024
5 changes: 5 additions & 0 deletions .changeset/three-onions-destroy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

voiceassistant: add will_synthesize_assistant_speech
49 changes: 49 additions & 0 deletions examples/voice-assistant/custom_pronunciation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import AsyncIterable

from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, tokenize
from livekit.agents.voice_assistant import VoiceAssistant
from livekit.plugins import cartesia, deepgram, openai, silero

load_dotenv()


async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)

await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)

def _before_tts_cb(assistant: VoiceAssistant, text: str | AsyncIterable[str]):
# The TTS is incorrectly pronouncing "LiveKit", so we'll replace it with a phonetic
# spelling
return tokenize.utils.replace_words(
text=text, replacements={"livekit": r"<<l|aɪ|v|k|ɪ|t|>>"}
)

# also for this example, we also intensify the keyword "LiveKit" to make it more likely to be
# recognized with the STT
deepgram_stt = deepgram.STT(keywords=[("LiveKit", 3.5)])

assistant = VoiceAssistant(
vad=silero.VAD.load(),
stt=deepgram_stt,
llm=openai.LLM(),
tts=cartesia.TTS(),
chat_ctx=initial_ctx,
before_tts_cb=_before_tts_cb,
)
assistant.start(ctx.room)

await assistant.say("Hey, LiveKit is awesome!", allow_interruptions=True)


if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
3 changes: 2 additions & 1 deletion livekit-agents/livekit/agents/tokenize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import basic
from . import basic, utils
from .token_stream import (
BufferedSentenceStream,
BufferedWordStream,
Expand All @@ -20,4 +20,5 @@
"BufferedSentenceStream",
"BufferedWordStream",
"basic",
"utils",
]
90 changes: 90 additions & 0 deletions livekit-agents/livekit/agents/tokenize/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from typing import AsyncIterable, overload

from . import _basic_word, tokenizer


@overload
def replace_words(
*,
text: str,
replacements: dict[str, str],
) -> str: ...


@overload
def replace_words(
*,
text: AsyncIterable[str],
replacements: dict[str, str],
) -> AsyncIterable[str]: ...


def replace_words(
*,
text: str | AsyncIterable[str],
replacements: dict[str, str],
) -> str | AsyncIterable[str]:
"""
Replace words in the given (async) text. The replacements are case-insensitive and the
replacement will keep the case of the original word.
Args:
text: text to replace words in
words: dictionary of words to replace
"""

replacements = {k.lower(): v for k, v in replacements.items()}

def _match_case(word, replacement):
if word.isupper():
return replacement.upper()
elif word.istitle():
return replacement.title()
else:
return replacement.lower()

def _process_words(text, words):
offset = 0
processed_index = 0
for word, start_index, end_index in words:
no_punctuation = word.rstrip("".join(tokenizer.PUNCTUATIONS))
punctuation_off = len(word) - len(no_punctuation)
replacement = replacements.get(no_punctuation.lower())
if replacement:
text = (
text[: start_index + offset]
+ _match_case(word, replacement)
+ text[end_index + offset - punctuation_off :]
)
offset += len(replacement) - len(word) + punctuation_off

processed_index = end_index + offset

return text, processed_index

if isinstance(text, str):
words = _basic_word.split_words(text, ignore_punctuation=False)
text, _ = _process_words(text, words)
return text
else:

async def _replace_words():
buffer = ""
async for chunk in text:
buffer += chunk
words = _basic_word.split_words(buffer, ignore_punctuation=False)

if len(words) <= 1:
continue

buffer, procesed_index = _process_words(buffer, words[:-1])
yield buffer[:procesed_index]
buffer = buffer[procesed_index:]

if buffer:
words = _basic_word.split_words(buffer, ignore_punctuation=False)
buffer, _ = _process_words(buffer, words)
yield buffer

return _replace_words()
27 changes: 17 additions & 10 deletions livekit-agents/livekit/agents/transcription/stt_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
from ..log import logger
from . import _utils

WillForwardTranscription = Callable[
BeforeForwardCallback = Callable[
["STTSegmentsForwarder", rtc.Transcription],
Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]],
]


def _default_will_forward_transcription(
WillForwardTranscription = BeforeForwardCallback


def _default_before_forward_cb(
fwd: STTSegmentsForwarder, transcription: rtc.Transcription
) -> rtc.Transcription:
return transcription
Expand All @@ -33,16 +36,24 @@ def __init__(
room: rtc.Room,
participant: rtc.Participant | str,
track: rtc.Track | rtc.TrackPublication | str | None = None,
will_forward_transcription: WillForwardTranscription = _default_will_forward_transcription,
before_forward_cb: BeforeForwardCallback = _default_before_forward_cb,
# backward compatibility
will_forward_transcription: WillForwardTranscription | None = None,
):
identity = participant if isinstance(participant, str) else participant.identity
if track is None:
track = _utils.find_micro_track_id(room, identity)
elif isinstance(track, (rtc.TrackPublication, rtc.Track)):
track = track.sid

if will_forward_transcription is not None:
logger.warning(
"will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead",
)
before_forward_cb = will_forward_transcription

self._room, self._participant_identity, self._track_id = room, identity, track
self._will_forward_transcription = will_forward_transcription
self._before_forward_cb = before_forward_cb
self._queue = asyncio.Queue[Optional[rtc.TranscriptionSegment]]()
self._main_task = asyncio.create_task(self._run())
self._current_id = _utils.segment_uuid()
Expand All @@ -60,16 +71,12 @@ async def _run(self):
segments=[seg], # no history for now
)

transcription = self._will_forward_transcription(
self, base_transcription
)
transcription = self._before_forward_cb(self, base_transcription)
if asyncio.iscoroutine(transcription):
transcription = await transcription

if not isinstance(transcription, rtc.Transcription):
transcription = _default_will_forward_transcription(
self, base_transcription
)
transcription = _default_before_forward_cb(self, base_transcription)

if transcription.segments and self._room.isconnected():
await self._room.local_participant.publish_transcription(
Expand Down
Loading
Loading