Commit
(#30): Refactor worker for readability, extensibility & configurability
Closes #30

* Refactor: Define WorkerManager & GruntWorker

WorkerManager is the one that receives tasks and delegates them to its
GruntWorkers running in the background. GruntWorker executes the tasks
and publishes the results back to the queue. Both of these should implement
the interface IWorker. The main logic for each worker lives in its `_main_loop`.
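
A rough sketch of this division of labour (illustrative only: the method
bodies and the queue plumbing here are assumptions, not the committed code):

import typing as t


class IWorker(t.Protocol):
    """Anything that runs a main loop against the task queue."""

    def run_forever(self) -> None:
        """Run the worker's `_main_loop` until terminated."""


class WorkerManager:
    """Receives tasks from the queue and delegates them to GruntWorkers."""

    def run_forever(self) -> None:
        self._main_loop()

    def _main_loop(self) -> None:
        # Poll the tasks channel and hand each task to a background GruntWorker.
        ...


class GruntWorker:
    """Executes delegated tasks and publishes the results back to the queue."""

    def run_forever(self) -> None:
        self._main_loop()

    def _main_loop(self) -> None:
        # Poll for delegated tasks, execute each one, publish its result.
        ...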

* Refactor: pubsub client & concurrency manager

In the future we want to enable users to configure and switch to
a different pubsub client & concurrency manager if need be. We start
with Redis as the default pubsub and multiprocessing as the default
concurrency manager. Users should be able to configure these using
environment variables.
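
For illustration, configuration might look like this (the variable names
below are assumptions, not the committed ones):

import os

# Hypothetical settings; the real names may differ.
pubsub_url = os.environ.get("AIOTASKQ_BROKER_URL", "redis://localhost:6379")
concurrency_type = os.environ.get("AIOTASKQ_CONCURRENCY_TYPE", "multiprocessing")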

* Use better interface

We use `typing.Protocol` classes as the interfaces for the worker, pubsub,
and concurrency manager.
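
The appeal of Protocols is structural typing: an implementation satisfies
the interface simply by matching its shape, with no inheritance required.
A toy example (not from this diff):

import typing as t


class IGreeter(t.Protocol):
    def greet(self, name: str) -> str: ...


class PlainGreeter:  # note: no subclassing of IGreeter needed
    def greet(self, name: str) -> str:
        return f"Hello, {name}"


def welcome(greeter: IGreeter) -> None:  # type-checks against the Protocol
    print(greeter.greet("world"))


welcome(PlainGreeter())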

* Refactor: Use pytest fixture for worker

* Fix warning & zombie child procs on sigterm/sigkill

Before, running `pidof $(which python)` after `./test.sh` would give
a list of pids, which means we were not properly killing child
processes.

With this change, we no longer see zombie child processes.

We fix this by making sure the worker manager handles TERM & INT signals
and propagates them to the workers (see the sketch after the references below).

Some helpful references regarding handling signals:
* https://stackoverflow.com/questions/42628795/indirectly-stopping-a-python-asyncio-event-loop-through-sigterm-has-no-effect
* https://stackoverflow.com/questions/67823770/how-to-propagate-sigterm-to-children-created-via-subprocess
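
A minimal sketch of that propagation pattern, assuming `multiprocessing`
children (illustrative, not the exact committed code):

import signal
import sys
import multiprocessing


def install_signal_handlers(children: list[multiprocessing.Process]) -> None:
    """Propagate TERM/INT from the manager process to its children."""

    def handler(signum: int, frame) -> None:
        for child in children:
            child.terminate()  # forward the signal to each child
        for child in children:
            child.join()  # reap the children so none are left as zombies
        sys.exit(0)

    signal.signal(signal.SIGTERM, handler)
    signal.signal(signal.SIGINT, handler)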

* Split worker.py into interfaces, pubsub & concurrency_manager

worker.py is now split into:

├── interfaces.py
├── pubsub.py
└── concurrency_manager.py

all of which can be imported by any worker.py or main.py.
Hopefully this will make the code more organized and better abstracted.
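
For instance, a worker implementation inside the package now only needs
the shared building blocks (names as they appear elsewhere in this diff):

from .concurrency_manager import ConcurrencyManager
from .interfaces import IConcurrencyManager, IPubSub, IWorker
from .pubsub import PubSub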

* Attempt to ensure child processes are covered in unit tests

The WorkerManager processes seem to be included in coverage, but the
GruntWorker processes are still not (presumably because they are
grandchild processes, which coverage doesn't handle?)

* Make main.py more DRY: Re-use PubSub facade from pubsub.py

* As a positive side effect, also closes #27

Small extras:
* Re-organize test files
* Split test_cli.py to test_cli.py and test_worker.py
* Ignore __main__.py from coverage since it is not coverable anyway
imranariffin committed Sep 19, 2022
1 parent 136ae39 commit 6b1b2a8
Showing 19 changed files with 723 additions and 273 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
@@ -1,2 +1,6 @@
 [run]
 source = src/
+parallel = True
+concurrency = multiprocessing
+omit =
+    src/aiotaskq/__main__.py
4 changes: 1 addition & 3 deletions README.md
@@ -37,9 +37,7 @@ import aiotaskq
 def some_task(b: int) -> int:
     # Some task with high cpu usage
     def _naive_fib(n: int) -> int:
-        if n <= 0:
-            return 0
-        elif n <= 2:
+        if n <= 2:
             return 1
         return _naive_fib(n - 1) + _naive_fib(n - 2)
     return _naive_fib(b)
4 changes: 1 addition & 3 deletions src/aiotaskq/__init__.py
@@ -11,10 +11,8 @@
 def some_task(b: int) -> int:
     # Some task with high cpu usage
     def _naive_fib(n: int) -> int:
-        if n <= 1:
+        if n <= 2:
             return 1
-        elif n <= 2:
-            return 2
         return _naive_fib(n - 1) + _naive_fib(n - 2)
     return _naive_fib(b)
19 changes: 14 additions & 5 deletions src/aiotaskq/__main__.py
@@ -2,21 +2,30 @@
 
 #!/usr/bin/env python
 
-import asyncio
 import typing as t
 
 import typer
 
-from aiotaskq.worker import Defaults, worker
+from .interfaces import ConcurrencyType
+from .worker import Defaults, run_worker_forever
 
 cli = typer.Typer()
 
 
 @cli.command(name="worker")
-def worker_command(app: str, concurrency: t.Optional[int] = Defaults.concurrency):
+def worker_command(
+    app: str,
+    concurrency: t.Optional[int] = Defaults.concurrency,
+    poll_interval_s: t.Optional[float] = Defaults.poll_interval_s,
+    concurrency_type: t.Optional[ConcurrencyType] = Defaults.concurrency_type,
+):
     """Command to start workers."""
-    loop = asyncio.get_event_loop()
-    loop.run_until_complete(worker(app_import_path=app, concurrency=concurrency))
+    run_worker_forever(
+        app_import_path=app,
+        concurrency=concurrency,
+        concurrency_type=concurrency_type,
+        poll_interval_s=poll_interval_s,
+    )
 
 
 @cli.command(name="metric")
51 changes: 51 additions & 0 deletions src/aiotaskq/concurrency_manager.py
@@ -0,0 +1,51 @@
from functools import cached_property
import logging
import multiprocessing
import os
import typing as t

from .exceptions import ConcurrencyTypeNotSupported
from .interfaces import ConcurrencyType, IConcurrencyManager, IProcess


class ConcurrencyManager:
    """The user-facing facade for creating the right concurrency manager implementation."""

    _instance: t.Optional["IConcurrencyManager"] = None

    @classmethod
    def get(cls, concurrency_type: str, concurrency: int) -> IConcurrencyManager:
        if cls._instance is not None:
            return cls._instance
        if concurrency_type == ConcurrencyType.MULTIPROCESSING:
            cls._instance = MultiProcessing(concurrency=concurrency)
            return cls._instance
        raise ConcurrencyTypeNotSupported(
            f'Concurrency type "{concurrency_type}" is not yet supported.'
        )


class MultiProcessing:
    """Implementation of a ConcurrencyManager that uses the `multiprocessing` built-in module."""

    def __init__(self, concurrency: int) -> None:
        self.concurrency = concurrency
        self.processes: dict[int, IProcess] = {}

    def start(self, func: t.Callable, *args: t.ParamSpecArgs) -> None:
        """Start each process under management."""
        for _ in range(self.concurrency):
            proc = multiprocessing.Process(target=func, args=args)
            proc.start()
            assert proc.pid is not None
            self.processes[proc.pid] = proc

    def terminate(self) -> None:
        """Terminate each process under management."""
        for proc in self.processes.values():
            self._logger.debug("Sending signal TERM to worker process [pid=%s]", proc.pid)
            proc.terminate()

    @cached_property
    def _logger(self):
        return logging.getLogger(f"[{os.getpid()}] [{self.__class__.__qualname__}]")
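
A usage sketch of the facade above (the entrypoint function is a
hypothetical placeholder, not part of this diff):

from aiotaskq.concurrency_manager import ConcurrencyManager


def worker_entrypoint() -> None:  # hypothetical stand-in for the real worker loop
    ...


if __name__ == "__main__":  # guard required by multiprocessing on some platforms
    manager = ConcurrencyManager.get(concurrency_type="multiprocessing", concurrency=4)
    manager.start(worker_entrypoint)  # spawns 4 background processes
    manager.terminate()  # sends TERM to each of them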
8 changes: 8 additions & 0 deletions src/aiotaskq/exceptions.py
@@ -7,3 +7,11 @@
 
 class ModuleInvalidForTask(Exception):
     """Attempt to convert to task a function in an invalid module."""
+
+
+class UrlNotSupported(Exception):
+    """This url is currently not supported."""
+
+
+class ConcurrencyTypeNotSupported(Exception):
+    """This concurrency type is currently not supported."""
116 changes: 116 additions & 0 deletions src/aiotaskq/interfaces.py
@@ -0,0 +1,116 @@
"""
Define all interfaces for the library.

Interfaces are mainly typing.Protocol classes, but may also include
other declarative classes like enums or types.
"""

import enum
import typing as t


Message = t.Union[str, bytes]


class PollResponse(t.TypedDict):
    """Define the dictionary returned from a pubsub."""

    type: str
    data: Message
    pattern: t.Optional[str]
    channel: bytes


class IProcess(t.Protocol):
    """
    Define the interface for a process used in the library.

    It's more or less the same as `multiprocessing.Process`, except this
    one has only the attributes that are necessary for the library, and
    slightly different typing, e.g. `pid` in our case is always an `int`,
    whereas the one from `multiprocessing.Process` is `Optional[int]`.
    This way we're not limited to `multiprocessing.Process` and may switch
    to another implementation if needed.
    """

    @property
    def pid(self) -> t.Optional[int]:
        """Return the process id (pid)."""

    def start(self):
        """Start running the process."""

    def terminate(self):
        """Send TERM signal to the process."""


class ConcurrencyType(str, enum.Enum):
    """Define supported concurrency types."""

    MULTIPROCESSING = "multiprocessing"


class IConcurrencyManager(t.Protocol):
    """
    Define the interface of a concurrency manager.

    It should be able to start a given number of processes and terminate them.
    """

    concurrency: int
    processes: dict[int, IProcess]

    def __init__(self, concurrency: int) -> None:
        """Initialize the concurrency manager."""

    def start(self, func: t.Callable, *args: t.ParamSpecArgs) -> None:
        """Start each process under management."""

    def terminate(self) -> None:
        """Terminate each process under management."""


class IPubSub(t.Protocol):
    """Define the interface of a pubsub client."""

    def __init__(self, url: str, poll_interval_s: float, *args, **kwargs):
        """Initialize the pubsub class."""

    async def __aenter__(self) -> "IPubSub":
        """Instantiate/start resources when entering the async context."""

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Close resources when exiting the async context."""

    async def publish(self, channel: str, message: Message) -> None:
        """Publish the given message to the given channel."""

    async def subscribe(self, channel: str) -> None:
        """Start subscribing to the given channel."""

    async def poll(self) -> PollResponse:
        """Poll for a new message from the subscribed channel, and return it."""


class IWorker(t.Protocol):
    """
    Define the interface for a worker.

    It should be tied to a specific app, and should be able to subscribe,
    poll and publish messages to the other workers.
    """

    pubsub: IPubSub
    app_import_path: str

    def run_forever(self) -> None:
        """Run the worker forever in a loop."""


class IWorkerManager(IWorker):
    """
    Define the interface for a worker manager.

    This is similar to a worker, but has more authority since it is the
    one who creates and kills other workers via its concurrency manager.
    """

    concurrency_manager: IConcurrencyManager
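
As a sketch of what the Protocols buy us, a test double can satisfy
IPubSub purely structurally (illustrative, not part of this diff):

import typing as t


class FakePubSub:
    """In-memory stand-in that matches the IPubSub Protocol by shape alone."""

    def __init__(self, url: str, poll_interval_s: float, *args, **kwargs):
        self._messages: list[tuple[str, t.Union[str, bytes]]] = []
        self._channel: t.Optional[str] = None

    async def __aenter__(self) -> "FakePubSub":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        return None

    async def publish(self, channel: str, message: t.Union[str, bytes]) -> None:
        self._messages.append((channel, message))

    async def subscribe(self, channel: str) -> None:
        self._channel = channel

    async def poll(self) -> dict:
        channel, data = self._messages.pop(0)
        return {"type": "message", "data": data, "pattern": None, "channel": channel.encode()}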
47 changes: 19 additions & 28 deletions src/aiotaskq/main.py
@@ -1,17 +1,16 @@
 """Module to define the main logic of the library."""
 
-import asyncio
 import inspect
 import json
 import logging
 from types import ModuleType
 import typing as t
 import uuid
 
-import aioredis
-
-from aiotaskq.constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL
-from aiotaskq.exceptions import ModuleInvalidForTask
+from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL
+from .exceptions import ModuleInvalidForTask
+from .interfaces import IPubSub, PollResponse
+from .pubsub import PubSub
 
 RT = t.TypeVar("RT")
 P = t.ParamSpec("P")
@@ -27,23 +26,22 @@ class AsyncResult(t.Generic[RT]):
     To get the result of corresponding task, use `.get()`.
     """
 
+    pubsub: IPubSub
     _result: RT
     _completed: bool = False
     _task_id: str
 
     def __init__(self, task_id: str) -> None:
         """Store task_id in AsyncResult instance."""
         self._task_id = task_id
+        self.pubsub = PubSub.get(url=REDIS_URL, poll_interval_s=0.01)
 
     async def get(self) -> RT:
         """Return the result of the task once finished."""
-        redis_client = aioredis.from_url(REDIS_URL)
-        async with redis_client.pubsub() as pubsub:
-            message: t.Optional[dict] = None
-            while message is None:
-                await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=self._task_id))
-                message = await pubsub.get_message(ignore_subscribe_messages=True)
-                await asyncio.sleep(0.1)
+        async with self.pubsub as pubsub:
+            message: PollResponse
+            await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=self._task_id))
+            message = await self.pubsub.poll()
         logger.debug("Message: %s", message)
         _result: RT = json.loads(message["data"])
         return _result
@@ -113,14 +111,16 @@ async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT:
                 "kwargs": kwargs,
             }
         )
-        publisher: aioredis.Redis = _get_redis_client()
-
-        logger.debug("Publishing task [task_id=%s, message=%s]", task_id, message)
-        await publisher.publish(TASKS_CHANNEL, message=message)
+        pubsub_ = PubSub.get(
+            url=REDIS_URL, poll_interval_s=0.01, max_connections=10, decode_responses=True
+        )
+        async with pubsub_ as pubsub:
+            logger.debug("Publishing task [task_id=%s, message=%s]", task_id, message)
+            await pubsub.publish(TASKS_CHANNEL, message=message)
 
-        logger.debug("Retrieving result for task [task_id=%s]", task_id)
-        async_result: AsyncResult[RT] = AsyncResult(task_id=task_id)
-        result: RT = await async_result.get()
+            logger.debug("Retrieving result for task [task_id=%s]", task_id)
+            async_result: AsyncResult[RT] = AsyncResult(task_id=task_id)
+            result: RT = await async_result.get()
 
         return result
@@ -145,12 +145,3 @@ def task(func: t.Callable[P, RT]) -> Task[P, RT]:
     task_.__qualname__ = f"{module_path}.{func.__name__}"
     task_.__module__ = module_path
     return task_
-
-
-_REDIS_CLIENT: t.Optional[aioredis.Redis] = None
-
-
-def _get_redis_client() -> aioredis.Redis:
-    if _REDIS_CLIENT is not None:
-        return _REDIS_CLIENT
-    return aioredis.from_url(REDIS_URL, max_connections=10, decode_responses=True)
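
For reference, the publish pattern of the PubSub facade as it is used in
this diff (a sketch; error handling omitted):

from aiotaskq.constants import REDIS_URL, TASKS_CHANNEL
from aiotaskq.pubsub import PubSub


async def publish_example(message: str) -> None:
    pubsub_ = PubSub.get(url=REDIS_URL, poll_interval_s=0.01)
    async with pubsub_ as pubsub:
        await pubsub.publish(TASKS_CHANNEL, message=message)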