diff --git a/.changeset/strange-apes-hide.md b/.changeset/strange-apes-hide.md new file mode 100644 index 000000000..c43f064f3 --- /dev/null +++ b/.changeset/strange-apes-hide.md @@ -0,0 +1,8 @@ +--- +"livekit-plugins-azure": minor +"livekit-plugins-turn-detector": patch +"livekit-plugins-openai": patch +"livekit-agents": patch +--- + +Improvements to end of turn plugin, ensure STT language settings. diff --git a/examples/voice-pipeline-agent/README.md b/examples/voice-pipeline-agent/README.md index 6f7e176fb..a8fb69410 100644 --- a/examples/voice-pipeline-agent/README.md +++ b/examples/voice-pipeline-agent/README.md @@ -34,7 +34,10 @@ export OPENAI_API_KEY= ### Install requirments: -`pip install -r requirements.txt` +``` +pip install -r requirements.txt +python minimal_assistant.py download-files +``` ### Run the agent worker: diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index c0ff31a3c..f3df2d60a 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -160,8 +160,12 @@ class AgentTranscriptionOptions: representing the hyphenated parts of the word.""" -class _EOUModel(Protocol): - async def predict_eou(self, chat_ctx: ChatContext) -> float: ... +class _TurnDetector(Protocol): + # When endpoint probability is below this threshold we think the user is not finished speaking + # so we will use a long delay + def unlikely_threshold(self) -> float: ... + def supports_language(self, language: str | None) -> bool: ... + async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float: ... class VoicePipelineAgent(utils.EventEmitter[EventTypes]): @@ -179,7 +183,7 @@ def __init__( stt: stt.STT, llm: LLM, tts: tts.TTS, - turn_detector: _EOUModel | None = None, + turn_detector: _TurnDetector | None = None, chat_ctx: ChatContext | None = None, fnc_ctx: FunctionContext | None = None, allow_interruptions: bool = True, @@ -595,7 +599,9 @@ def _on_final_transcript(ev: stt.SpeechEvent) -> None: ): self._synthesize_agent_reply() - self._deferred_validation.on_human_final_transcript(new_transcript) + self._deferred_validation.on_human_final_transcript( + new_transcript, ev.alternatives[0].language + ) words = self._opts.transcription.word_tokenizer.tokenize( text=new_transcript @@ -1105,24 +1111,22 @@ class _DeferredReplyValidation: LATE_TRANSCRIPT_TOLERANCE = 1.5 # late compared to end of speech - # When endpoint probability is below this threshold we think the user is not finished speaking - # so we will use a long delay - UNLIKELY_ENDPOINT_THRESHOLD = 0.15 - # Long delay to use when the model thinks the user is still speaking + # TODO: make this configurable UNLIKELY_ENDPOINT_DELAY = 6 def __init__( self, validate_fnc: Callable[[], None], min_endpointing_delay: float, - turn_detector: _EOUModel | None, + turn_detector: _TurnDetector | None, agent: VoicePipelineAgent, ) -> None: self._turn_detector = turn_detector self._validate_fnc = validate_fnc self._validating_task: asyncio.Task | None = None self._last_final_transcript: str = "" + self._last_language: str | None = None self._last_recv_end_of_speech_time: float = 0.0 self._speaking = False @@ -1134,8 +1138,9 @@ def __init__( def validating(self) -> bool: return self._validating_task is not None and not self._validating_task.done() - def on_human_final_transcript(self, transcript: str) -> None: + def on_human_final_transcript(self, transcript: str, language: str | None) -> None: self._last_final_transcript += " " + transcript.strip() # type: ignore + self._last_language = language if self._speaking: return @@ -1193,9 +1198,13 @@ def _run(self, delay: float) -> None: @utils.log_exceptions(logger=logger) async def _run_task(chat_ctx: ChatContext, delay: float) -> None: await asyncio.sleep(delay) - if self._turn_detector is not None: - eou_prob = await self._turn_detector.predict_eou(chat_ctx) - if eou_prob < self.UNLIKELY_ENDPOINT_THRESHOLD: + if ( + self._turn_detector is not None + and self._turn_detector.supports_language(self._last_language) + ): + eot_prob = await self._turn_detector.predict_end_of_turn(chat_ctx) + unlikely_threshold = self._turn_detector.unlikely_threshold() + if eot_prob < unlikely_threshold: await asyncio.sleep(self.UNLIKELY_ENDPOINT_DELAY) self._reset_states() diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py index ac68760b6..d705a7f2c 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py @@ -16,6 +16,7 @@ import contextlib import os import weakref +from copy import deepcopy from dataclasses import dataclass from livekit import rtc @@ -55,7 +56,10 @@ def __init__( segmentation_silence_timeout_ms: int | None = None, segmentation_max_time_ms: int | None = None, segmentation_strategy: str | None = None, - languages: list[str] = [], # when empty, auto-detect the language + # Azure handles multiple languages and can auto-detect the language used. It requires the candidate set to be set. + languages: list[str] = ["en-US"], + # for compatibility with other STT plugins + language: str | None = None, ): """ Create a new instance of Azure STT. @@ -83,6 +87,9 @@ def __init__( "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION must be set" ) + if language: + languages = [language] + self._config = STTOptions( speech_key=speech_key, speech_region=speech_region, @@ -109,18 +116,28 @@ async def _recognize_impl( def stream( self, *, + languages: list[str] | None = None, language: str | None = None, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": - stream = SpeechStream(stt=self, opts=self._config, conn_options=conn_options) + config = deepcopy(self._config) + if language and not languages: + languages = [language] + if languages: + config.languages = languages + stream = SpeechStream(stt=self, opts=config, conn_options=conn_options) self._streams.add(stream) return stream - def update_options(self, *, language: str | None = None): - if language is not None: - self._config.languages = [language] + def update_options( + self, *, language: str | None = None, languages: list[str] | None = None + ): + if language and not languages: + languages = [language] + if languages is not None: + self._config.languages = languages for stream in self._streams: - stream.update_options(language=language) + stream.update_options(languages=languages) class SpeechStream(stt.SpeechStream): @@ -139,9 +156,13 @@ def __init__( self._loop = asyncio.get_running_loop() self._reconnect_event = asyncio.Event() - def update_options(self, *, language: str | None = None): - if language: - self._opts.languages = [language] + def update_options( + self, *, language: str | None = None, languages: list[str] | None = None + ): + if language and not languages: + languages = [language] + if languages: + self._opts.languages = languages self._reconnect_event.set() async def _run(self) -> None: @@ -206,6 +227,9 @@ def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs): if not text: return + if not detected_lg and self._opts.languages: + detected_lg = self._opts.languages[0] + final_data = stt.SpeechData( language=detected_lg, confidence=1.0, text=evt.result.text ) @@ -224,6 +248,9 @@ def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs): if not text: return + if not detected_lg and self._opts.languages: + detected_lg = self._opts.languages[0] + interim_data = stt.SpeechData( language=detected_lg, confidence=0.0, text=evt.result.text ) @@ -303,7 +330,7 @@ def _create_speech_recognizer( ) auto_detect_source_language_config = None - if config.languages: + if config.languages and len(config.languages) > 1: auto_detect_source_language_config = ( speechsdk.languageconfig.AutoDetectSourceLanguageConfig( languages=config.languages diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py index 9c0387fb5..e3f19972a 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py @@ -148,14 +148,18 @@ async def _recognize_impl( ), model=self._opts.model, language=config.language, - response_format="json", + # verbose_json returns language and other details + response_format="verbose_json", timeout=httpx.Timeout(30, connect=conn_options.timeout), ) return stt.SpeechEvent( type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=[ - stt.SpeechData(text=resp.text or "", language=language or "") + stt.SpeechData( + text=resp.text or "", + language=resp.language or config.language or "", + ) ], ) diff --git a/livekit-plugins/livekit-plugins-turn-detector/README.md b/livekit-plugins/livekit-plugins-turn-detector/README.md index eee8d00c3..988706784 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/README.md +++ b/livekit-plugins/livekit-plugins-turn-detector/README.md @@ -1,2 +1,48 @@ # LiveKit Plugins Turn Detector +This plugin introduces end-of-turn detection for LiveKit Agents using a custom open-weight model to determine when a user has finished speaking. + +Traditional voice agents use VAD (voice activity detection) for end-of-turn detection. However, VAD models lack language understanding, often causing false positives where the agent interrupts the user before they finish speaking. + +By leveraging a language model specifically trained for this task, this plugin offers a more accurate and robust method for detecting end-of-turns. The current version supports English only and should not be used when targeting other languages. + +## Installation + +```bash +pip install livekit-plugins-turn-detector +``` + +## Usage + +This plugin is designed to be used with the `VoicePipelineAgent`: + +```python +from livekit.plugins import turn_detector + +agent = VoicePipelineAgent( + ... + turn_detector=turn_detector.EOUModel(), +) +``` + +## Running your agent + +This plugin requires model files. Before starting your agent for the first time, or when building Docker images for deployment, run the following command to download the model files: + +```bash +python my_agent.py download-files +``` + +## Model system requirements + +The end-of-turn model is optimized to run on CPUs with modest system requirements. It is designed to run on the same server hosting your agents. On a 4-core server instance, it completes inference in under 100ms with minimal CPU usage. + +The model requires 1.5GB of RAM and runs within a shared inference server, supporting multiple concurrent sessions. + +We are working to reduce the CPU and memory requirements in future releases. + +## License + +The plugin source code is licensed under the Apache-2.0 license. + +The end-of-turn model is licensed under the [LiveKit Model License](https://huggingface.co/livekit/turn-detector/blob/main/LICENSE). diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py index 6ef44c6eb..d5f21799e 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py @@ -12,7 +12,7 @@ from .log import logger -HG_MODEL = "livekit/opt-125m-endpoint-detector-2" +HG_MODEL = "livekit/turn-detector" PUNCS = string.punctuation.replace("'", "") MAX_HISTORY = 4 @@ -113,12 +113,30 @@ def run(self, data: bytes) -> bytes | None: class EOUModel: - def __init__(self, inference_executor: InferenceExecutor | None = None) -> None: + def __init__( + self, + inference_executor: InferenceExecutor | None = None, + unlikely_threshold: float = 0.15, + ) -> None: self._executor = ( inference_executor or get_current_job_context().inference_executor ) + self._unlikely_threshold = unlikely_threshold + + def unlikely_threshold(self) -> float: + return self._unlikely_threshold + + def supports_language(self, language: str | None) -> bool: + if language is None: + return False + parts = language.lower().split("-") + # certain models use language codes (DG, AssemblyAI), others use full names (like OAI) + return parts[0] == "en" or parts[0] == "english" async def predict_eou(self, chat_ctx: llm.ChatContext) -> float: + return await self.predict_end_of_turn(chat_ctx) + + async def predict_end_of_turn(self, chat_ctx: llm.ChatContext) -> float: messages = [] for msg in chat_ctx.messages: diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py index 11cb57b75..2b29634ad 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/log.py @@ -1,3 +1,3 @@ import logging -logger = logging.getLogger("livekit.plugins.eou") +logger = logging.getLogger("livekit.plugins.turn_detector") diff --git a/livekit-plugins/livekit-plugins-turn-detector/setup.py b/livekit-plugins/livekit-plugins-turn-detector/setup.py index 00f7ed255..1bd17caa1 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/setup.py +++ b/livekit-plugins/livekit-plugins-turn-detector/setup.py @@ -49,8 +49,13 @@ license="Apache-2.0", packages=setuptools.find_namespace_packages(include=["livekit.*"]), python_requires=">=3.9.0", - install_requires=["livekit-agents>=0.11", "transformers>=4.46", "numpy>=1.26"], - package_data={"livekit.plugins.eou": ["py.typed"]}, + install_requires=[ + "livekit-agents>=0.11", + "transformers>=4.46", + "numpy>=1.26", + "torch>=2.0", + ], + package_data={"livekit.plugins.turn_detector": ["py.typed"]}, project_urls={ "Documentation": "https://docs.livekit.io", "Website": "https://livekit.io/", diff --git a/tests/test_stt.py b/tests/test_stt.py index f9e52b3d8..836cfd20a 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -109,6 +109,10 @@ async def _stream_output(): if event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT: text += event.alternatives[0].text + # ensure STT is tagging languages correctly + language = event.alternatives[0].language + assert language is not None + assert language.lower().startswith("en") if event.type == agents.stt.SpeechEventType.END_OF_SPEECH: recv_start = False