Skip to content

Commit

Permalink
add threaded job runners (#684)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Sep 7, 2024
1 parent 87fee72 commit a8b777f
Show file tree
Hide file tree
Showing 14 changed files with 450 additions and 62 deletions.
5 changes: 5 additions & 0 deletions .changeset/hungry-students-end.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

ipc: add threaded job runner
5 changes: 2 additions & 3 deletions examples/simple-color/agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import asyncio
import logging
import os
from dotenv import load_dotenv, dotenv_values
import random

from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import JobContext, WorkerOptions, cli
import random

# Load environment variables
load_dotenv()
Expand Down
3 changes: 2 additions & 1 deletion livekit-agents/livekit/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from . import ipc, llm, stt, tokenize, transcription, tts, utils, vad, voice_assistant
from .job import AutoSubscribe, JobContext, JobProcess, JobRequest
from .job import AutoSubscribe, JobContext, JobExecutorType, JobProcess, JobRequest
from .plugin import Plugin
from .version import __version__
from .worker import Worker, WorkerOptions, WorkerPermissions, WorkerType
Expand All @@ -27,6 +27,7 @@
"JobProcess",
"JobContext",
"JobRequest",
"JobExecutorType",
"AutoSubscribe",
"Plugin",
"ipc",
Expand Down
18 changes: 16 additions & 2 deletions livekit-agents/livekit/agents/ipc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
from . import channel, proc_pool, proto, supervised_proc
from . import (
channel,
job_executor,
proc_job_executor,
proc_pool,
proto,
thread_job_executor,
)

__all__ = ["proto", "channel", "proc_pool", "supervised_proc"]
__all__ = [
"proto",
"channel",
"proc_pool",
"proc_job_executor",
"thread_job_executor",
"job_executor",
]
29 changes: 29 additions & 0 deletions livekit-agents/livekit/agents/ipc/job_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from typing import Any, Protocol

from ..job import RunningJobInfo


class JobExecutor(Protocol):
    """Structural interface (``typing.Protocol``) for a job executor.

    Only the shape of an executor is declared here: three properties and five
    coroutines. No behaviour is implemented; every body is ``...``.
    """

    # Read-only flag; presumably True once start() has run -- confirm against
    # the concrete executors.
    @property
    def started(self) -> bool: ...

    # Opaque value of any type (or None); exposed as a read/write property.
    @property
    def start_arguments(self) -> Any | None: ...

    @start_arguments.setter
    def start_arguments(self, value: Any | None) -> None: ...

    # RunningJobInfo associated with this executor, or None.
    # NOTE(review): presumably None until launch_job() is called -- confirm.
    @property
    def running_job(self) -> RunningJobInfo | None: ...

    # Lifecycle coroutines. Their exact ordering and semantics are defined by
    # the implementations, not by this protocol.
    async def start(self) -> None: ...

    async def join(self) -> None: ...

    async def initialize(self) -> None: ...

    async def aclose(self) -> None: ...

    # Takes the RunningJobInfo describing the job to run on this executor.
    async def launch_job(self, info: RunningJobInfo) -> None: ...
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import socket
import threading
from dataclasses import dataclass
from typing import Optional
from typing import Any, Callable, Optional

from livekit import rtc

Expand Down Expand Up @@ -86,8 +86,8 @@ class JobTask:


def _start_job(
args: proto.ProcStartArgs,
proc: JobProcess,
job_entrypoint_fnc: Callable[[JobContext], Any],
start_req: proto.StartJobRequest,
exit_proc_fut: asyncio.Event,
cch: utils.aio.duplex_unix._AsyncDuplex,
Expand Down Expand Up @@ -131,7 +131,7 @@ def _on_ctx_shutdown(reason: str) -> None:
async def _run_job_task() -> None:
utils.http_context._new_session_ctx()
job_entry_task = asyncio.create_task(
args.job_entrypoint_fnc(job_ctx), name="job_entrypoint"
job_entrypoint_fnc(job_ctx), name="job_entrypoint"
)

async def _warn_not_connected_task():
Expand Down Expand Up @@ -189,7 +189,9 @@ def log_exception(t: asyncio.Task) -> None:


async def _async_main(
args: proto.ProcStartArgs, proc: JobProcess, mp_cch: socket.socket
proc: JobProcess,
job_entrypoint_fnc: Callable[[JobContext], Any],
mp_cch: socket.socket,
) -> None:
cch = await duplex_unix._AsyncDuplex.open(mp_cch)

Expand All @@ -202,7 +204,8 @@ async def _read_ipc_task():
nonlocal job_task
while True:
msg = await channel.arecv_message(cch, proto.IPC_MESSAGES)
no_msg_timeout.reset()
with contextlib.suppress(utils.aio.SleepFinished):
no_msg_timeout.reset()

if isinstance(msg, proto.PingRequest):
pong = proto.PongResponse(
Expand All @@ -212,7 +215,7 @@ async def _read_ipc_task():

if isinstance(msg, proto.StartJobRequest):
assert job_task is None, "job task already running"
job_task = _start_job(args, proc, msg, exit_proc_fut, cch)
job_task = _start_job(proc, job_entrypoint_fnc, msg, exit_proc_fut, cch)

if isinstance(msg, proto.ShutdownRequest):
if job_task is None:
Expand Down Expand Up @@ -246,7 +249,18 @@ def _done_cb(task: asyncio.Task) -> None:
await cch.aclose()


def main(args: proto.ProcStartArgs) -> None:
@dataclass
class ProcStartArgs:
    """Argument bundle received by ``proc_main`` as its single parameter."""

    # Callable taking the JobProcess; presumably run once before any job
    # starts -- confirm in proc_main.
    initialize_process_fnc: Callable[[JobProcess], Any]
    # Job entrypoint; proc_main forwards it to _async_main.
    job_entrypoint_fnc: Callable[[JobContext], Any]
    # Socket for the logging channel (presumably the child-side end -- confirm).
    log_cch: socket.socket
    # Socket for the IPC message channel (presumably the child-side end -- confirm).
    mp_cch: socket.socket
    # NOTE(review): presumably toggles asyncio debug mode on the event loop -- confirm.
    asyncio_debug: bool
    # Optional opaque user value; defaults to None.
    user_arguments: Any | None = None


def proc_main(args: ProcStartArgs) -> None:
"""main function for the job process when using the ProcessJobRunner"""
root_logger = logging.getLogger()
root_logger.setLevel(logging.NOTSET)

Expand Down Expand Up @@ -275,7 +289,8 @@ def main(args: proto.ProcStartArgs) -> None:
channel.send_message(cch, proto.InitializeResponse())

main_task = loop.create_task(
_async_main(args, job_proc, cch.detach()), name="job_proc_main"
_async_main(job_proc, args.job_entrypoint_fnc, cch.detach()),
name="job_proc_main",
)
while not main_task.done():
try:
Expand All @@ -286,6 +301,52 @@ def main(args: proto.ProcStartArgs) -> None:
except duplex_unix.DuplexClosed:
pass
finally:
cch.close()
log_handler.close()
loop.run_until_complete(loop.shutdown_default_executor())


@dataclass
class ThreadStartArgs:
    """Argument bundle received by ``thread_main`` as its single parameter."""

    # Socket that thread_main wraps with duplex_unix._Duplex.open() to
    # exchange IPC messages.
    mp_cch: socket.socket
    # Called by thread_main with the freshly created JobProcess, before the
    # InitializeResponse is sent.
    initialize_process_fnc: Callable[[JobProcess], Any]
    # Job entrypoint; thread_main forwards it to _async_main.
    job_entrypoint_fnc: Callable[[JobContext], Any]
    # Passed to JobProcess(start_arguments=...).
    user_arguments: Any | None
    # Passed to loop.set_debug() on the thread's event loop.
    asyncio_debug: bool
    # Zero-argument callback invoked from thread_main's ``finally`` clause.
    join_fnc: Callable[[], None]


def thread_main(
    args: ThreadStartArgs,
) -> None:
    """Entry point of a job runner thread (thread-based job executor).

    Creates a dedicated asyncio event loop for the calling thread, performs the
    initialization handshake over ``args.mp_cch`` and then runs ``_async_main``
    until it completes.

    Handshake: the first message received must be a ``proto.InitializeRequest``.
    ``args.initialize_process_fnc`` is then called with a new ``JobProcess``
    (built from ``args.user_arguments``) and a ``proto.InitializeResponse`` is
    sent back.

    Error handling: ``duplex_unix.DuplexClosed`` is ignored; any other
    ``Exception`` is logged and not re-raised.

    Cleanup: ``args.join_fnc`` is called on exit, the loop's default executor
    is shut down, and the event loop is closed.
    """
    tid = threading.get_native_id()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_debug(args.asyncio_debug)
    loop.slow_callback_duration = 0.1  # 100ms

    try:
        # Opened inside the try block so that a failure while wrapping the
        # socket still reaches the cleanup below (join_fnc + loop shutdown)
        # instead of escaping the thread with the loop left open.
        cch = duplex_unix._Duplex.open(args.mp_cch)

        init_req = channel.recv_message(cch, proto.IPC_MESSAGES)
        assert isinstance(
            init_req, proto.InitializeRequest
        ), "first message must be InitializeRequest"
        job_proc = JobProcess(start_arguments=args.user_arguments)

        logger.debug("initializing job runner", extra={"tid": tid})
        args.initialize_process_fnc(job_proc)
        logger.debug("job runner initialized", extra={"tid": tid})
        channel.send_message(cch, proto.InitializeResponse())

        # The synchronous duplex is detached and its socket handed over to
        # the async side for the remainder of the thread's life.
        main_task = loop.create_task(
            _async_main(job_proc, args.job_entrypoint_fnc, cch.detach()),
            name="job_proc_main",
        )
        loop.run_until_complete(main_task)
    except duplex_unix.DuplexClosed:
        pass
    except Exception:
        logger.exception("error while running job process", extra={"tid": tid})
    finally:
        try:
            args.join_fnc()
            loop.run_until_complete(loop.shutdown_default_executor())
        finally:
            # Each thread owns its own loop; close it explicitly so its
            # selector and self-pipe are released now rather than at GC time.
            loop.close()
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..job import JobContext, JobProcess, RunningJobInfo
from ..log import logger
from ..utils.aio import duplex_unix
from . import channel, proc_main, proto
from . import channel, job_main, proto


class LogQueueListener:
Expand Down Expand Up @@ -69,7 +69,7 @@ class _ProcOpts:
close_timeout: float


class SupervisedProc:
class ProcJobExecutor:
def __init__(
self,
*,
Expand Down Expand Up @@ -155,7 +155,7 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None:
log_listener = LogQueueListener(log_pch, _add_proc_ctx_log)
log_listener.start()

self._proc_args = proto.ProcStartArgs(
self._proc_args = job_main.ProcStartArgs(
initialize_process_fnc=self._opts.initialize_process_fnc,
job_entrypoint_fnc=self._opts.job_entrypoint_fnc,
log_cch=mp_log_cch,
Expand All @@ -165,7 +165,7 @@ def _add_proc_ctx_log(record: logging.LogRecord) -> None:
)

self._proc = self._opts.mp_ctx.Process( # type: ignore
target=proc_main.main, args=(self._proc_args,), name="job_proc"
target=job_main.proc_main, args=(self._proc_args,), name="job_proc"
)

self._proc.start()
Expand Down
54 changes: 35 additions & 19 deletions livekit-agents/livekit/agents/ipc/proc_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from typing import Any, Awaitable, Callable, Literal

from .. import utils
from ..job import JobContext, JobProcess, RunningJobInfo
from ..job import JobContext, JobExecutorType, JobProcess, RunningJobInfo
from ..log import logger
from ..utils import aio
from .supervised_proc import SupervisedProc
from . import proc_job_executor, thread_job_executor
from .job_executor import JobExecutor

EventTypes = Literal[
"process_created", "process_started", "process_ready", "process_closed"
Expand All @@ -26,10 +27,12 @@ def __init__(
num_idle_processes: int,
initialize_timeout: float,
close_timeout: float,
job_executor_type: JobExecutorType,
mp_ctx: BaseContext,
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__()
self._job_executor_type = job_executor_type
self._mp_ctx = mp_ctx
self._initialize_process_fnc = initialize_process_fnc
self._job_entrypoint_fnc = job_entrypoint_fnc
Expand All @@ -39,20 +42,20 @@ def __init__(

self._init_sem = asyncio.Semaphore(MAX_CONCURRENT_INITIALIZATIONS)
self._proc_needed_sem = asyncio.Semaphore(num_idle_processes)
self._warmed_proc_queue = asyncio.Queue[SupervisedProc]()
self._processes: list[SupervisedProc] = []
self._warmed_proc_queue = asyncio.Queue[JobExecutor]()
self._executors: list[JobExecutor] = []
self._started = False
self._closed = False

@property
def processes(self) -> list[SupervisedProc]:
return self._processes
def processes(self) -> list[JobExecutor]:
return self._executors

def get_by_job_id(self, job_id: str) -> SupervisedProc | None:
def get_by_job_id(self, job_id: str) -> JobExecutor | None:
return next(
(
x
for x in self._processes
for x in self._executors
if x.running_job and x.running_job.job.id == job_id
),
None,
Expand All @@ -79,16 +82,29 @@ async def launch_job(self, info: RunningJobInfo) -> None:

@utils.log_exceptions(logger=logger)
async def _proc_watch_task(self) -> None:
proc = SupervisedProc(
initialize_process_fnc=self._initialize_process_fnc,
job_entrypoint_fnc=self._job_entrypoint_fnc,
initialize_timeout=self._initialize_timeout,
close_timeout=self._close_timeout,
mp_ctx=self._mp_ctx,
loop=self._loop,
)
proc: JobExecutor
if self._job_executor_type == JobExecutorType.THREAD:
proc = thread_job_executor.ThreadJobExecutor(
initialize_process_fnc=self._initialize_process_fnc,
job_entrypoint_fnc=self._job_entrypoint_fnc,
initialize_timeout=self._initialize_timeout,
close_timeout=self._close_timeout,
loop=self._loop,
)
elif self._job_executor_type == JobExecutorType.PROCESS:
proc = proc_job_executor.ProcJobExecutor(
initialize_process_fnc=self._initialize_process_fnc,
job_entrypoint_fnc=self._job_entrypoint_fnc,
initialize_timeout=self._initialize_timeout,
close_timeout=self._close_timeout,
mp_ctx=self._mp_ctx,
loop=self._loop,
)
else:
raise ValueError(f"unsupported job executor: {self._job_executor_type}")

try:
self._processes.append(proc)
self._executors.append(proc)

async with self._init_sem:
if self._closed:
Expand All @@ -109,7 +125,7 @@ async def _proc_watch_task(self) -> None:
await proc.join()
self.emit("process_closed", proc)
finally:
self._processes.remove(proc)
self._executors.remove(proc)

@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
Expand All @@ -121,5 +137,5 @@ async def _main_task(self) -> None:
watch_tasks.append(task)
task.add_done_callback(watch_tasks.remove)
except asyncio.CancelledError:
await asyncio.gather(*[proc.aclose() for proc in self._processes])
await asyncio.gather(*[proc.aclose() for proc in self._executors])
await asyncio.gather(*watch_tasks)
Loading

0 comments on commit a8b777f

Please sign in to comment.