Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

voiceassistant: add before_tts_cb callback #706

Merged
merged 15 commits into from
Sep 8, 2024
5 changes: 5 additions & 0 deletions .changeset/three-onions-destroy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

voiceassistant: add will_synthesize_assistant_speech
52 changes: 52 additions & 0 deletions examples/voice-assistant/custom_pronunciation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations
from typing import AsyncIterable
import asyncio

Check failure on line 3 in examples/voice-assistant/custom_pronunciation.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F401)

examples/voice-assistant/custom_pronunciation.py:3:8: F401 `asyncio` imported but unused

from dotenv import load_dotenv
from livekit import rtc

Check failure on line 6 in examples/voice-assistant/custom_pronunciation.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F401)

examples/voice-assistant/custom_pronunciation.py:6:21: F401 `livekit.rtc` imported but unused
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, utils
from livekit.agents.voice_assistant import VoiceAssistant
from livekit.plugins import deepgram, cartesia, openai, silero

load_dotenv()

Check failure on line 11 in examples/voice-assistant/custom_pronunciation.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

examples/voice-assistant/custom_pronunciation.py:1:1: I001 Import block is un-sorted or un-formatted


async def entrypoint(ctx: JobContext):
    """Voice-assistant example that fixes the TTS pronunciation of "LiveKit"."""
    # System prompt that establishes the assistant persona.
    system_ctx = llm.ChatContext().append(
        role="system",
        text=(
            "You are a voice assistant created by LiveKit. Your interface with users will be voice. "
            "You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
        ),
    )

    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)

    async def _adjust_pronunciation(
        assistant: VoiceAssistant, text: str | AsyncIterable[str]
    ):
        # Cartesia TTS mispronounces "LiveKit"; substitute a phonetic spelling
        # before the text reaches the TTS engine.
        return utils.replace_words(
            text=text, replacements={"LiveKit": "<<L|ˈaɪ|ve>>Kit"}
        )

    # Boost the keyword "LiveKit" so the STT is more likely to recognize it.
    stt_impl = deepgram.STT(keywords=[("LiveKit", 2.0)])

    assistant = VoiceAssistant(
        vad=silero.VAD.load(),
        stt=stt_impl,
        llm=openai.LLM(),
        tts=cartesia.TTS(),
        chat_ctx=system_ctx,
        will_synthesize_assistant_speech=_adjust_pronunciation,
    )
    assistant.start(ctx.room)

    await assistant.say("Hey, how can I help you today?", allow_interruptions=True)


if __name__ == "__main__":
    # Launch the agent worker; `entrypoint` runs for each assigned job.
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
3 changes: 2 additions & 1 deletion livekit-agents/livekit/agents/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from . import aio, audio, codecs, http_context, images
from .event_emitter import EventEmitter
from .exp_filter import ExpFilter
from .log import log_exceptions
from .misc import AudioBuffer, merge_frames, shortuuid, time_ms
from .misc import AudioBuffer, merge_frames, shortuuid, time_ms, replace_words
from .moving_average import MovingAverage

__all__ = [

Check failure on line 8 in livekit-agents/livekit/agents/utils/__init__.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

livekit-agents/livekit/agents/utils/__init__.py:1:1: I001 Import block is un-sorted or un-formatted
"AudioBuffer",
"merge_frames",
"time_ms",
"shortuuid",
"replace_words",
"http_context",
"ExpFilter",
"MovingAverage",
Expand Down
55 changes: 54 additions & 1 deletion livekit-agents/livekit/agents/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import time
import uuid
from typing import List, Union
from typing import AsyncIterable, List, Union, overload

from livekit import rtc

Expand Down Expand Up @@ -48,3 +50,54 @@ def time_ms() -> int:

def shortuuid() -> str:
return str(uuid.uuid4().hex)[:12]


@overload
def replace_words(
    *,
    text: str,
    replacements: dict[str, str],
) -> str: ...


@overload
def replace_words(
    *,
    text: AsyncIterable[str],
    replacements: dict[str, str],
) -> AsyncIterable[str]: ...


def replace_words(
    *,
    text: str | AsyncIterable[str],
    replacements: dict[str, str],
) -> str | AsyncIterable[str]:
    """Replace words in text with other strings.

    Args:
        text: text to replace words in — either a plain string or an async
            stream of string chunks (e.g. incremental LLM output)
        replacements: mapping of original word -> replacement

    Returns:
        A string when ``text`` is a string, otherwise an async iterable
        yielding the rewritten stream (whitespace is preserved as-is).
    """
    if isinstance(text, str):
        for word, replacement in replacements.items():
            text = text.replace(word, replacement)
        return text

    def _process_token(token: str) -> str:
        # Apply substring replacement so that words adjacent to punctuation
        # (e.g. "world,") are still rewritten — keeps the streaming path
        # consistent with the plain-str path above. The previous exact-match
        # lookup (replacements.get(token, token)) missed such tokens.
        for word, replacement in replacements.items():
            token = token.replace(word, replacement)
        return token

    async def _replace_words():
        # Accumulate non-whitespace characters into a token; flush (with
        # replacement applied) whenever whitespace is encountered.
        buffer = ""
        async for chunk in text:
            for char in chunk:
                if not char.isspace():
                    buffer += char
                else:
                    if buffer:
                        yield _process_token(buffer)
                        buffer = ""
                    yield char

        # Flush the trailing token once the stream is exhausted.
        if buffer:
            yield _process_token(buffer)

    return _replace_words()
15 changes: 10 additions & 5 deletions livekit-agents/livekit/agents/voice_assistant/agent_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import time
from typing import Any, AsyncIterable, Callable, Union
from typing import Any, AsyncIterable, Awaitable, Callable, Union

from livekit import rtc

Expand All @@ -12,7 +12,7 @@
from .agent_playout import AgentPlayout, PlayoutHandle
from .log import logger

SpeechSource = Union[AsyncIterable[str], str]
SpeechSource = Union[AsyncIterable[str], str, Awaitable[str]]


class SynthesisHandle:
Expand Down Expand Up @@ -155,10 +155,15 @@ def _will_forward_transcription(
@utils.log_exceptions(logger=logger)
async def _synthesize_task(self, handle: SynthesisHandle) -> None:
"""Synthesize speech from the source"""
if isinstance(handle._speech_source, str):
co = _str_synthesis_task(handle._speech_source, handle)
speech_source = handle._speech_source

if isinstance(speech_source, Awaitable):
speech_source = await speech_source
co = _str_synthesis_task(speech_source, handle)
elif isinstance(speech_source, str):
co = _str_synthesis_task(speech_source, handle)
else:
co = _stream_synthesis_task(handle._speech_source, handle)
co = _stream_synthesis_task(speech_source, handle)

synth = asyncio.create_task(co)
synth.add_done_callback(lambda _: handle._buf_ch.close())
Expand Down
57 changes: 46 additions & 11 deletions livekit-agents/livekit/agents/voice_assistant/voice_assistant.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
from __future__ import annotations

import warnings
import asyncio
import contextvars
import time
from dataclasses import dataclass
from typing import Any, AsyncIterable, Awaitable, Callable, Literal, Optional, Union

from livekit import rtc

from .. import stt, tokenize, tts, utils, vad
from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
from .agent_output import AgentOutput, SynthesisHandle
from .agent_playout import AgentPlayout
from .human_input import HumanInput
from .log import logger
from .plotter import AssistantPlotter
from .speech_handle import SpeechHandle

WillSynthesizeAssistantReply = Callable[

PreLLMGenerationCallback = Callable[

Check failure on line 22 in livekit-agents/livekit/agents/voice_assistant/voice_assistant.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

livekit-agents/livekit/agents/voice_assistant/voice_assistant.py:1:1: I001 Import block is un-sorted or un-formatted
["VoiceAssistant", ChatContext],
Union[Optional[LLMStream], Awaitable[Optional[LLMStream]]],
]

WillSynthesizeAssistantReply = PreLLMGenerationCallback

PreTTSGenerationCallback = Callable[
theomonnom marked this conversation as resolved.
Show resolved Hide resolved
["VoiceAssistant", Union[str, AsyncIterable[str]]],
Union[str, AsyncIterable[str], Awaitable[str]],
]


EventTypes = Literal[
"user_started_speaking",
"user_stopped_speaking",
Expand Down Expand Up @@ -66,19 +76,26 @@
return self._llm_stream


def _default_will_synthesize_assistant_reply(
def _default_llm_generation_cb(
assistant: VoiceAssistant, chat_ctx: ChatContext
) -> LLMStream:
return assistant.llm.chat(chat_ctx=chat_ctx, fnc_ctx=assistant.fnc_ctx)


def _default_tts_generation_cb(
assistant: VoiceAssistant, text: str | AsyncIterable[str]
) -> str | AsyncIterable[str]:
return text


@dataclass(frozen=True)
class _ImplOptions:
allow_interruptions: bool
int_speech_duration: float
int_min_words: int
preemptive_synthesis: bool
will_synthesize_assistant_reply: WillSynthesizeAssistantReply
pre_llm_generation_cb: PreLLMGenerationCallback
pre_tts_generation_cb: PreTTSGenerationCallback
plotting: bool
transcription: AssistantTranscriptionOptions

Expand Down Expand Up @@ -123,9 +140,12 @@
interrupt_min_words: int = 0,
preemptive_synthesis: bool = True,
transcription: AssistantTranscriptionOptions = AssistantTranscriptionOptions(),
will_synthesize_assistant_reply: WillSynthesizeAssistantReply = _default_will_synthesize_assistant_reply,
pre_llm_generation_cb: PreLLMGenerationCallback = _default_llm_generation_cb,
pre_tts_generation_cb: PreTTSGenerationCallback = _default_tts_generation_cb,
plotting: bool = False,
loop: asyncio.AbstractEventLoop | None = None,
# backward compatibility
will_synthesize_assistant_reply: PreLLMGenerationCallback | None = None,
) -> None:
"""
Create a new VoiceAssistant.
Expand All @@ -143,21 +163,34 @@
Defaults to 0 as this may increase the latency depending on the STT.
preemptive_synthesis: Whether to preemptively synthesize responses.
transcription: Options for assistant transcription.
will_synthesize_assistant_reply: Callback called when the assistant is about to synthesize a reply.
pre_llm_generation_cb: Callback called when the assistant is about to synthesize a reply.
This can be used to customize the reply (e.g: inject context/RAG).
pre_tts_generation_cb: Callback called when the assistant is about to
synthesize a speech. This can be used to customize text before the speech synthesis.
(e.g: editing the pronunciation of a word).
plotting: Whether to enable plotting for debugging. matplotlib must be installed.
loop: Event loop to use. Default to asyncio.get_event_loop().
"""
super().__init__()
self._loop = loop or asyncio.get_event_loop()

if will_synthesize_assistant_reply is not None:
warnings.warn(
"will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use on_generate_llm instead",
theomonnom marked this conversation as resolved.
Show resolved Hide resolved
DeprecationWarning,
stacklevel=2,
)
on_generate_llm = will_synthesize_assistant_reply

Check failure on line 183 in livekit-agents/livekit/agents/voice_assistant/voice_assistant.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F841)

livekit-agents/livekit/agents/voice_assistant/voice_assistant.py:183:13: F841 Local variable `on_generate_llm` is assigned to but never used

self._opts = _ImplOptions(
plotting=plotting,
allow_interruptions=allow_interruptions,
int_speech_duration=interrupt_speech_duration,
int_min_words=interrupt_min_words,
preemptive_synthesis=preemptive_synthesis,
transcription=transcription,
will_synthesize_assistant_reply=will_synthesize_assistant_reply,
pre_llm_generation_cb=pre_llm_generation_cb,
pre_tts_generation_cb=pre_tts_generation_cb,
)
self._plotter = AssistantPlotter(self._loop)

Expand Down Expand Up @@ -505,15 +538,13 @@
ChatMessage.create(text=handle.user_question, role="user")
)

llm_stream = self._opts.will_synthesize_assistant_reply(self, copied_ctx)
llm_stream = self._opts.pre_llm_generation_cb(self, copied_ctx)
if asyncio.iscoroutine(llm_stream):
llm_stream = await llm_stream

# fallback to default impl if no custom/user stream is returned
if not isinstance(llm_stream, LLMStream):
llm_stream = _default_will_synthesize_assistant_reply(
self, chat_ctx=copied_ctx
)
llm_stream = _default_llm_generation_cb(self, chat_ctx=copied_ctx)

synthesis_handle = self._synthesize_agent_speech(handle.id, llm_stream)
handle.initialize(source=llm_stream, synthesis_handle=synthesis_handle)
Expand Down Expand Up @@ -712,9 +743,13 @@
if isinstance(source, LLMStream):
source = _llm_stream_to_str_iterable(speech_id, source)

speech_source = self._opts.pre_tts_generation_cb(self, source)
if speech_source is None:
logger.error("pre_tts_generation_cb must return str or AsyncIterable[str]")

return self._agent_output.synthesize(
speech_id=speech_id,
transcript=source,
transcript=speech_source,
transcription=self._opts.transcription.agent_transcription,
transcription_speed=self._opts.transcription.agent_transcription_speed,
sentence_tokenizer=self._opts.transcription.sentence_tokenizer,
Expand Down
31 changes: 30 additions & 1 deletion tests/test_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from livekit.agents import tokenize
from livekit.agents import tokenize, utils
from livekit.agents.tokenize import basic
from livekit.plugins import nltk

Expand Down Expand Up @@ -141,3 +141,32 @@ def test_hyphenate_word():
for i, word in enumerate(HYPHENATOR_TEXT):
hyphenated = basic.hyphenate_word(word)
assert hyphenated == HYPHENATOR_EXPECTED[i]


# Fixture for utils.replace_words: "world" -> "universe" and
# "framework" -> "library"; surrounding punctuation must be preserved.
REPLACE_TEXT = "This is a test. Hello world, I'm creating this agents.. framework"
REPLACE_EXPECTED = (
    "This is a test. Hello universe, I'm creating this agents.. library"
)

REPLACE_REPLACEMENTS = {
    "world": "universe",
    "framework": "library",
}


def test_replace_words():
    # Exercises the synchronous (plain-string) path of utils.replace_words.
    result = utils.replace_words(
        text=REPLACE_TEXT, replacements=REPLACE_REPLACEMENTS
    )
    assert result == REPLACE_EXPECTED


async def text_replace_words_async():
pattern = [1, 2, 4]
text = REPLACE_TEXT
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))

for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]
Loading