From d3dc60855908fa5f4daba4a40a5fac49b60a1e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Sun, 8 Sep 2024 12:32:49 -0700 Subject: [PATCH] add JobContext.wait_for_participant (#712) --- .changeset/shy-chefs-boil.md | 5 +++ livekit-agents/livekit/agents/job.py | 65 +++++++++++++++++++++------- 2 files changed, 54 insertions(+), 16 deletions(-) create mode 100644 .changeset/shy-chefs-boil.md diff --git a/.changeset/shy-chefs-boil.md b/.changeset/shy-chefs-boil.md new file mode 100644 index 000000000..a9a6f81b3 --- /dev/null +++ b/.changeset/shy-chefs-boil.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +add JobContext.wait_for_participant diff --git a/livekit-agents/livekit/agents/job.py b/livekit-agents/livekit/agents/job.py index f71d860e4..19574b71b 100644 --- a/livekit-agents/livekit/agents/job.py +++ b/livekit-agents/livekit/agents/job.py @@ -74,7 +74,7 @@ def __init__( Callable[[JobContext, rtc.RemoteParticipant], Coroutine[None, None, None]] ] = [] self._participant_tasks = dict[Tuple[str, Callable], asyncio.Task[None]]() - self._room.on("participant_connected", self._on_participant_connected) + self._room.on("participant_connected", self._participant_available) @property def proc(self) -> JobProcess: @@ -104,6 +104,39 @@ def add_shutdown_callback( ) -> None: self._shutdown_callbacks.append(callback) + async def wait_for_participant( + self, *, identity: str | None = None + ) -> rtc.RemoteParticipant: + """ + Returns a participant that matches the given identity. If identity is None, the first + participant that joins the room will be returned. + If the participant has already joined, the function will return immediately. + """ + if not self._room.isconnected(): + raise RuntimeError("room is not connected") + + fut = asyncio.Future[rtc.RemoteParticipant]() + + for p in self._room.remote_participants.values(): + if ( + identity is None or p.identity == identity + ) and p.kind != rtc.ParticipantKind.PARTICIPANT_KIND_AGENT: + fut.set_result(p) + break + + def _on_participant_connected(p: rtc.RemoteParticipant): + if ( + identity is None or p.identity == identity + ) and p.kind != rtc.ParticipantKind.PARTICIPANT_KIND_AGENT: + self._room.off("participant_connected", _on_participant_connected) + if not fut.done(): + fut.set_result(p) + + if not fut.done(): + self._room.on("participant_connected", _on_participant_connected) + + return await fut + async def connect( self, *, @@ -127,33 +160,20 @@ async def connect( await self._room.connect(self._info.url, self._info.token, options=room_options) self._on_connect() for p in self._room.remote_participants.values(): - self._on_participant_connected(p) + self._participant_available(p) _apply_auto_subscribe_opts(self._room, auto_subscribe) def shutdown(self, reason: str = "") -> None: self._on_shutdown(reason) - def _on_participant_connected(self, p: rtc.RemoteParticipant) -> None: - for coro in self._participant_entrypoints: - if (p.identity, coro) in self._participant_tasks: - logger.warning( - f"a participant has joined before a prior participant task matching the same identity has finished: '{p.identity}'" - ) - task_name = f"part-entry-{p.identity}-{coro.__name__}" - task = asyncio.create_task(coro(self, p), name=task_name) - self._participant_tasks[(p.identity, coro)] = task - task.add_done_callback( - lambda _: self._participant_tasks.pop((p.identity, coro)) - ) - def add_participant_entrypoint( self, entrypoint_fnc: Callable[ [JobContext, rtc.RemoteParticipant], Coroutine[None, None, None] ], ): - """Adds an entrypoint function to be run when a participant that matches the filter joins the room. In cases where + """Adds an entrypoint function to be run when a participant joins the room. In cases where the participant has already joined, the entrypoint will be run immediately. Multiple unique entrypoints can be added and they will each be run in parallel for each participant. """ @@ -163,6 +183,19 @@ def add_participant_entrypoint( self._participant_entrypoints.append(entrypoint_fnc) + def _participant_available(self, p: rtc.RemoteParticipant) -> None: + for coro in self._participant_entrypoints: + if (p.identity, coro) in self._participant_tasks: + logger.warning( + f"a participant has joined before a prior participant task matching the same identity has finished: '{p.identity}'" + ) + task_name = f"part-entry-{p.identity}-{coro.__name__}" + task = asyncio.create_task(coro(self, p), name=task_name) + self._participant_tasks[(p.identity, coro)] = task + task.add_done_callback( + lambda _: self._participant_tasks.pop((p.identity, coro)) + ) + def _apply_auto_subscribe_opts(room: rtc.Room, auto_subscribe: AutoSubscribe) -> None: if auto_subscribe not in (AutoSubscribe.AUDIO_ONLY, AutoSubscribe.VIDEO_ONLY):