From 6b1b2a8edd87ed3f52b3794477f5560b299f3308 Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Tue, 30 Aug 2022 08:52:00 -0400 Subject: [PATCH] (#30): Refactor worker for readability, extensibility & configurability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #30 * Refactor: Define WorkerManager & GruntWorker WorkerManager is the one who receives tasks and delegates them to its GruntWorkers running in the background. GruntWorker executes the tasks and publishes the results back to the queue. Both of these should implement the IWorker interface. The main logic for the worker is in `_main_loop`. * Refactor: pubsub client & concurrency manager In the future we want to enable users to configure and switch to a different pubsub client & concurrency manager if need be. We start with Redis as the default pubsub and multiprocessing as the default concurrency manager. Users should be able to configure these using environment variables. * Use better interfaces We use Protocols as the interfaces for the worker, pubsub & concurrency manager (see the illustrative sketch below). * Refactor: Use pytest fixture for worker * Fix warning & zombie child procs on sigterm/sigint Before, running `pidof $(which python)` after `./test.sh` would give a list of pids, which means we were not properly killing child processes. With this change, we no longer see zombie child processes. We fix this by making sure the worker manager handles TERM & INT signals and propagates them to the workers. Some helpful references regarding handling signals: * https://stackoverflow.com/questions/42628795/indirectly-stopping-a-python-asyncio-event-loop-through-sigterm-has-no-effect * https://stackoverflow.com/questions/67823770/how-to-propagate-sigterm-to-children-created-via-subprocess * Split worker.py into interfaces, pubsub & concurrency_manager worker.py is now split into: ├── interfaces.py ├── pubsub.py └── concurrency_manager.py all of which can be imported by any worker.py or main.py. Hopefully this will make the code more organized and well-abstracted. * Attempt to ensure child processes are covered in unit tests The WorkerManager processes seem to be included in coverage, but GruntWorker processes are still not (I guess because they are grandchild processes and coverage doesn't handle that?)
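As an illustrative aside (a hedged sketch, not code from this patch — names are simplified from the actual modules in the diff below): with Protocols, WorkerManager and GruntWorker satisfy IWorker structurally, with no inheritance required.

    import typing as t

    class IWorker(t.Protocol):
        def run_forever(self) -> None:
            """Run the worker forever in a loop."""

    class WorkerManager:  # receives tasks and delegates them to GruntWorkers
        def run_forever(self) -> None: ...

    class GruntWorker:  # executes tasks and publishes results back to the queue
        def run_forever(self) -> None: ...

    def start(worker: IWorker) -> None:
        # Both classes type-check as IWorker via structural subtyping, so
        # implementations can be swapped without touching this function.
        worker.run_forever()

Because `typing.Protocol` uses structural subtyping, swapping in a different pubsub or concurrency manager implementation only requires matching the interface, not subclassing it.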
* Make main.py more DRY: Re-use PubSub facade from pubsub.py * As a positive side effect, this also closes #27 Small extras: * Re-organize test files * Split test_cli.py into test_cli.py and test_worker.py * Ignore __main__.py from coverage since it is not coverable anyway --- .coveragerc | 4 + README.md | 4 +- src/aiotaskq/__init__.py | 4 +- src/aiotaskq/__main__.py | 19 +- src/aiotaskq/concurrency_manager.py | 51 +++ src/aiotaskq/exceptions.py | 8 + src/aiotaskq/interfaces.py | 116 +++++++ src/aiotaskq/main.py | 47 ++- src/aiotaskq/pubsub.py | 67 ++++ src/aiotaskq/tests/__init__.py | 0 src/aiotaskq/tests/apps/simple_app.py | 2 - src/aiotaskq/tests/conftest.py | 46 +++ src/aiotaskq/tests/test_cli.py | 90 +----- .../tests/test_concurrency_manager.py | 20 ++ src/aiotaskq/tests/test_integration.py | 37 +-- src/aiotaskq/tests/test_pubsub.py | 18 ++ src/aiotaskq/tests/test_worker.py | 163 ++++++++++ src/aiotaskq/worker.py | 296 ++++++++++-------- test.sh | 4 + 19 files changed, 723 insertions(+), 273 deletions(-) create mode 100644 src/aiotaskq/concurrency_manager.py create mode 100644 src/aiotaskq/interfaces.py create mode 100644 src/aiotaskq/pubsub.py create mode 100644 src/aiotaskq/tests/__init__.py create mode 100644 src/aiotaskq/tests/conftest.py create mode 100644 src/aiotaskq/tests/test_concurrency_manager.py create mode 100644 src/aiotaskq/tests/test_pubsub.py create mode 100644 src/aiotaskq/tests/test_worker.py diff --git a/.coveragerc b/.coveragerc index 4af8d36..45875af 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,6 @@ [run] source = src/ +parallel = True +concurrency = multiprocessing +omit = + src/aiotaskq/__main__.py diff --git a/README.md b/README.md index 18eab3a..8a5a56b 100644 --- a/README.md +++ b/README.md @@ -37,9 +37,7 @@ import aiotaskq def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: - if n <= 0: - return 0 - elif n <= 2: + if n <= 2: return 1 return _naive_fib(n - 1) + _naive_fib(n - 2) return _naive_fib(b) diff --git a/src/aiotaskq/__init__.py b/src/aiotaskq/__init__.py index 36d7cac..22ac8bb 100644 --- a/src/aiotaskq/__init__.py +++ b/src/aiotaskq/__init__.py @@ -11,10 +11,8 @@ def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: - if n <= 1: + if n <= 2: return 1 - elif n <= 2: - return 2 return _naive_fib(n - 1) + _naive_fib(n - 2) return _naive_fib(b) diff --git a/src/aiotaskq/__main__.py b/src/aiotaskq/__main__.py index 689dd33..e45031c 100755 --- a/src/aiotaskq/__main__.py +++ b/src/aiotaskq/__main__.py @@ -2,21 +2,30 @@ #!/usr/bin/env python -import asyncio import typing as t import typer -from aiotaskq.worker import Defaults, worker +from .interfaces import ConcurrencyType +from .worker import Defaults, run_worker_forever cli = typer.Typer() @cli.command(name="worker") -def worker_command(app: str, concurrency: t.Optional[int] = Defaults.concurrency): +def worker_command( + app: str, + concurrency: t.Optional[int] = Defaults.concurrency, + poll_interval_s: t.Optional[float] = Defaults.poll_interval_s, + concurrency_type: t.Optional[ConcurrencyType] = Defaults.concurrency_type, +): """Command to start workers.""" - loop = asyncio.get_event_loop() - loop.run_until_complete(worker(app_import_path=app, concurrency=concurrency)) + run_worker_forever( + app_import_path=app, + concurrency=concurrency, + concurrency_type=concurrency_type, + poll_interval_s=poll_interval_s, + ) @cli.command(name="metric") diff --git a/src/aiotaskq/concurrency_manager.py
b/src/aiotaskq/concurrency_manager.py new file mode 100644 index 0000000..371f185 --- /dev/null +++ b/src/aiotaskq/concurrency_manager.py @@ -0,0 +1,51 @@ +from functools import cached_property +import logging +import multiprocessing +import os +import typing as t + +from .exceptions import ConcurrencyTypeNotSupported +from .interfaces import ConcurrencyType, IConcurrencyManager, IProcess + + +class ConcurrencyManager: + """The user-facing facade for creating the right concurrency manager implementation.""" + + _instance: t.Optional["IConcurrencyManager"] = None + + @classmethod + def get(cls, concurrency_type: str, concurrency: int) -> IConcurrencyManager: + """Return the singleton concurrency manager for the given concurrency type.""" + if cls._instance: + return cls._instance + if concurrency_type == ConcurrencyType.MULTIPROCESSING: + cls._instance = MultiProcessing(concurrency=concurrency) + return cls._instance + raise ConcurrencyTypeNotSupported( + f'Concurrency type "{concurrency_type}" is not yet supported.' + ) + + +class MultiProcessing: + """Implementation of a ConcurrencyManager that uses the `multiprocessing` built-in module.""" + + def __init__(self, concurrency: int) -> None: + self.concurrency = concurrency + self.processes: dict[int, IProcess] = {} + + def start(self, func: t.Callable, *args: t.ParamSpecArgs) -> None: + """Start each process under management.""" + for _ in range(self.concurrency): + proc = multiprocessing.Process(target=func, args=args) + proc.start() + assert proc.pid is not None + self.processes[proc.pid] = proc + + def terminate(self) -> None: + """Terminate each process under management.""" + for proc in self.processes.values(): + self._logger.debug("Sending signal TERM to background worker process [pid=%s]", proc.pid) + proc.terminate() + + @cached_property + def _logger(self): + return logging.getLogger(f"[{os.getpid()}] [{self.__class__.__qualname__}]") diff --git a/src/aiotaskq/exceptions.py b/src/aiotaskq/exceptions.py index c179bc7..4f58a4f 100644 --- a/src/aiotaskq/exceptions.py +++ b/src/aiotaskq/exceptions.py @@ -7,3 +7,11 @@ class ModuleInvalidForTask(Exception): """Attempt to convert to task a function in an invalid module.""" + + +class UrlNotSupported(Exception): + """This url is currently not supported.""" + + +class ConcurrencyTypeNotSupported(Exception): + """This concurrency type is currently not supported.""" diff --git a/src/aiotaskq/interfaces.py b/src/aiotaskq/interfaces.py new file mode 100644 index 0000000..e2fc0d0 --- /dev/null +++ b/src/aiotaskq/interfaces.py @@ -0,0 +1,116 @@ +""" +Define all interfaces for the library. + +Interfaces are mainly typing.Protocol classes, but may also include +other declarative classes like enums or Types. +""" + +import enum +import typing as t + + +Message = t.Union[str, bytes] + + +class PollResponse(t.TypedDict): + """Define the dictionary returned from a pubsub.""" + + type: str + data: Message + pattern: t.Optional[str] + channel: bytes + + +class IProcess(t.Protocol): + """ + Define the interface for a process used in the library. + + It's more or less the same as `multiprocessing.Process`, except this + one only has the attributes that are necessary for the library. This way we're + not limited to `multiprocessing.Process` and may switch to another implementation + if needed.
+ """ + + @property + def pid(self) -> t.Optional[int]: + """Return the process id (pid).""" + + def start(self): + """Start running the process.""" + + def terminate(self): + """Send TERM signal to the process.""" + + +class ConcurrencyType(str, enum.Enum): + """Define supported concurrency types.""" + + MULTIPROCESSING = "multiprocessing" + + +class IConcurrencyManager(t.Protocol): + """ + Define the interface of a concurrency manager. + + It should be able to start x number of processes given & terminate them. + """ + + concurrency: int + processes: dict[int, IProcess] + + def __init__(self, concurrency: int) -> None: + """Initialize the concurrency manager.""" + + def start(self, func: t.Callable, *args: t.ParamSpecArgs) -> None: + """Start each process under management.""" + + def terminate(self) -> None: + """Terminate each process under management.""" + + +class IPubSub(t.Protocol): + def __init__(self, url: str, poll_interval_s: float, *args, **kwargs): + """Initialize the pubsub class.""" + + async def __aenter__(self) -> "IPubSub": + """Instantiate/start resources when entering the async context.""" + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """Close resources when entering the async context.""" + + async def publish(self, channel: str, message: Message) -> None: + """Publish the given messaage to the given channel.""" + + async def subscribe(self, channel: str) -> None: + """Start subscribing to the given channel.""" + + async def poll(self) -> PollResponse: + """Poll for new message from the subscribed channel, and return it.""" + + +class IWorker(t.Protocol): + """ + Define the interface for a worker. + + It should also be tied to a specific app. + It should be able to subscribe, poll and publish messages to the other worker. + """ + + pubsub: IPubSub + app_import_path: str + + def run_forever(self) -> None: + """Run the worker forever in a loop.""" + + +class IWorkerManager(IWorker): + """ + Define the interface for a worker manager. + + This is similar to a worker, but has more authority since it is the one + one who create and kill other workers via its concurrency manager. + """ + + concurrency_manager: IConcurrencyManager diff --git a/src/aiotaskq/main.py b/src/aiotaskq/main.py index 3893416..31ab455 100644 --- a/src/aiotaskq/main.py +++ b/src/aiotaskq/main.py @@ -1,6 +1,5 @@ """Module to define the main logic of the library.""" -import asyncio import inspect import json import logging @@ -8,10 +7,10 @@ import typing as t import uuid -import aioredis - -from aiotaskq.constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL -from aiotaskq.exceptions import ModuleInvalidForTask +from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL +from .exceptions import ModuleInvalidForTask +from .interfaces import IPubSub, PollResponse +from .pubsub import PubSub RT = t.TypeVar("RT") P = t.ParamSpec("P") @@ -27,6 +26,7 @@ class AsyncResult(t.Generic[RT]): To get the result of corresponding task, use `.get()`. 
""" + pubsub: IPubSub _result: RT _completed: bool = False _task_id: str @@ -34,16 +34,14 @@ class AsyncResult(t.Generic[RT]): def __init__(self, task_id: str) -> None: """Store task_id in AsyncResult instance.""" self._task_id = task_id + self.pubsub = PubSub.get(url=REDIS_URL, poll_interval_s=0.01) async def get(self) -> RT: """Return the result of the task once finished.""" - redis_client = aioredis.from_url(REDIS_URL) - async with redis_client.pubsub() as pubsub: - message: t.Optional[dict] = None - while message is None: - await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=self._task_id)) - message = await pubsub.get_message(ignore_subscribe_messages=True) - await asyncio.sleep(0.1) + async with self.pubsub as pubsub: + message: PollResponse + await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=self._task_id)) + message = await self.pubsub.poll() logger.debug("Message: %s", message) _result: RT = json.loads(message["data"]) return _result @@ -113,14 +111,16 @@ async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT: "kwargs": kwargs, } ) - publisher: aioredis.Redis = _get_redis_client() - - logger.debug("Publishing task [task_id=%s, message=%s]", task_id, message) - await publisher.publish(TASKS_CHANNEL, message=message) + pubsub_ = PubSub.get( + url=REDIS_URL, poll_interval_s=0.01, max_connections=10, decode_responses=True + ) + async with pubsub_ as pubsub: + logger.debug("Publishing task [task_id=%s, message=%s]", task_id, message) + await pubsub.publish(TASKS_CHANNEL, message=message) - logger.debug("Retrieving result for task [task_id=%s]", task_id) - async_result: AsyncResult[RT] = AsyncResult(task_id=task_id) - result: RT = await async_result.get() + logger.debug("Retrieving result for task [task_id=%s]", task_id) + async_result: AsyncResult[RT] = AsyncResult(task_id=task_id) + result: RT = await async_result.get() return result @@ -145,12 +145,3 @@ def task(func: t.Callable[P, RT]) -> Task[P, RT]: task_.__qualname__ = f"{module_path}.{func.__name__}" task_.__module__ = module_path return task_ - - -_REDIS_CLIENT: t.Optional[aioredis.Redis] = None - - -def _get_redis_client() -> aioredis.Redis: - if _REDIS_CLIENT is not None: - return _REDIS_CLIENT - return aioredis.from_url(REDIS_URL, max_connections=10, decode_responses=True) diff --git a/src/aiotaskq/pubsub.py b/src/aiotaskq/pubsub.py new file mode 100644 index 0000000..0c8778a --- /dev/null +++ b/src/aiotaskq/pubsub.py @@ -0,0 +1,67 @@ +import asyncio +import typing as t + +import aioredis as redis + +from .exceptions import UrlNotSupported +from .interfaces import IPubSub, Message, PollResponse + + +class PubSub: + """The user-facing facade for creating the right pubsub implementation based on url.""" + + _instance: t.Optional[IPubSub] = None + + @classmethod + def get(cls, url: str, poll_interval_s: float, **kwargs) -> IPubSub: + """ + Return the correct pubsub instance based on url. + + Currently supports only Redis (url="redis*"). 
+ """ + if cls._instance: + return cls._instance + + if url.startswith("redis"): + cls._instance = PubSubRedis(url=url, poll_interval_s=poll_interval_s, **kwargs) + return cls._instance + raise UrlNotSupported(f'Url "{url}" is currently not supported.') + + +class PubSubRedis: + """Redis implementation of a pubsub.""" + + def __init__(self, url: str, poll_interval_s: float, **kwargs) -> None: + self._url = url + self._poll_interval_s = poll_interval_s + self._redis_client = redis.Redis.from_url(url=self._url, **kwargs) + self._redis_pubsub = self._redis_client.pubsub() + + async def __aenter__(self) -> "PubSubRedis": + """Initialize redis client and redis pubsub client on entering the async context.""" + await self._redis_client.__aenter__() + await self._redis_pubsub.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """Close redis client and redis pubsub client on exiting the async context.""" + await self._redis_pubsub.__aexit__(exc_type, exc_value, traceback) + await self._redis_client.__aexit__(exc_type, exc_value, traceback) + + async def publish(self, channel: str, message: Message) -> None: + """Publish the given message to the given channel.""" + await self._redis_client.publish(channel=channel, message=message) + + async def subscribe(self, channel: str) -> None: + """Start subscribing to the given channel.""" + await self._redis_pubsub.subscribe(channel) + + async def poll(self) -> PollResponse: + """Keep requesting for a new message on some interval, and return one only if available.""" + message: t.Optional[Message] + while True: + message = await self._redis_pubsub.get_message(ignore_subscribe_messages=True) + if message is not None: + break + await asyncio.sleep(self._poll_interval_s) + return message diff --git a/src/aiotaskq/tests/__init__.py b/src/aiotaskq/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aiotaskq/tests/apps/simple_app.py b/src/aiotaskq/tests/apps/simple_app.py index 6d10aa1..e49b5ba 100644 --- a/src/aiotaskq/tests/apps/simple_app.py +++ b/src/aiotaskq/tests/apps/simple_app.py @@ -20,8 +20,6 @@ def join(ls: list, delimiter: str = ",") -> str: def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: - if n <= 0: - return 0 if n <= 2: return 1 return _naive_fib(n - 1) + _naive_fib(n - 2) diff --git a/src/aiotaskq/tests/conftest.py b/src/aiotaskq/tests/conftest.py new file mode 100644 index 0000000..476888c --- /dev/null +++ b/src/aiotaskq/tests/conftest.py @@ -0,0 +1,46 @@ +import asyncio +import multiprocessing +import typing as t + +import pytest + +from aiotaskq.interfaces import ConcurrencyType +from aiotaskq.worker import Defaults, run_worker_forever + + +class WorkerFixture: + proc: multiprocessing.Process + + async def start( + self, + app: str, + concurrency: t.Optional[int] = Defaults.concurrency, + concurrency_type: t.Optional[ConcurrencyType] = Defaults.concurrency_type, + poll_interval_s: t.Optional[float] = Defaults.poll_interval_s, + ) -> None: + proc = multiprocessing.Process( + target=lambda: run_worker_forever( + app_import_path=app, + concurrency=concurrency, + concurrency_type=concurrency_type, + poll_interval_s=poll_interval_s, + ) + ) + proc.start() + # Wait for worker to be ready, otherwise some tests will get stuck, because + # we're publishing a task before the worker managed to suscribe. 
You can + # replicate this by adding `await asyncio.sleep(1)` right before the line in + # worker.py where the worker manager calls `await pubsub.subscribe()`. + await asyncio.sleep(0.5) + self.proc = proc + + def terminate(self): + if self.proc: + self.proc.terminate() + + +@pytest.fixture +def worker(): + worker = WorkerFixture() + yield worker + worker.terminate() diff --git a/src/aiotaskq/tests/test_cli.py b/src/aiotaskq/tests/test_cli.py index ce804c9..e0f980c 100644 --- a/src/aiotaskq/tests/test_cli.py +++ b/src/aiotaskq/tests/test_cli.py @@ -1,7 +1,5 @@ import multiprocessing import os -import subprocess -import time def test_root_show_proper_help_message(): @@ -39,88 +37,10 @@ def test_worker_show_proper_help_message(): " APP [required]\n" "\n" "Options:\n" - f" --concurrency INTEGER [default: {multiprocessing.cpu_count()}]\n" - " --help Show this message and exit.\n" + f" --concurrency INTEGER [default: {multiprocessing.cpu_count()}]\n" + " --poll-interval-s FLOAT [default: 0.01]\n" + " --concurrency-type [multiprocessing]\n" + " [default: multiprocessing]\n" + " --help Show this message and exit.\n" ) assert output == output_expected - - -def test_worker_concurrency_starts_child_workers(): - """ - Assert that when --concurrency N option is provided, N child processes will be spawn. - """ - # Given that the worker cli is run with "--concurrency 4" option - concurrency = 4 - bash_command = [ - "aiotaskq", - "worker", - "--concurrency", - str(concurrency), - "aiotaskq.tests.apps.simple_app", - ] - with subprocess.Popen(bash_command) as worker_cli_process: - worker_cli_pid = worker_cli_process.pid - # Once we've given enough time for child worker processes to be spawned - time.sleep(0.5) - # Then the number of child worker processes spawned should be the same as requested - with os.popen(f"pgrep -P {worker_cli_pid} | wc -l") as child_process_counter: - child_process_count = int(child_process_counter.read()) - assert child_process_count == concurrency - worker_cli_process.terminate() - - -def test_worker_concurrency_starts_child_workers_with_default_concurrency(): - """ - Assert that when --concurrency is NOT provided, N child processes will be spawned, N=cpu cores.
- """ - # Given that the machine has the following cpu count - cpu_count_on_machine = multiprocessing.cpu_count() - # When the worker cli is run without "--concurrency" option - bash_command = ["aiotaskq", "worker", "aiotaskq.tests.apps.simple_app"] - with subprocess.Popen(bash_command) as worker_cli_process: - worker_cli_pid = worker_cli_process.pid - # Once we've given enough time for child worker processes to be spawned - time.sleep(0.5) - # Then the number of child worker processes spawned should be the same - # as cpu core on machine - with os.popen(f"pgrep -P {worker_cli_pid} | wc -l") as child_process_counter: - child_process_count = int(child_process_counter.read()) - assert child_process_count == cpu_count_on_machine - worker_cli_process.terminate() - - -def test_worker_incorrect_app(): - # Given that the worker is started with an incorrect app name - incorrect_app_name = "some.incorrect.app.name" - bash_command = ["aiotaskq", "worker", incorrect_app_name] - worker_cli_process = subprocess.Popen(args=bash_command, stdout=subprocess.PIPE) - with WrapClose(proc=worker_cli_process) as worker_cli_process_pipe: - # Then the worker process should print error message - output = str(worker_cli_process_pipe.read()) - output_expected = ( - "Error at argument `--app_import_path some.incorrect.app.name`: " - '"some.incorrect.app.name" is not a path to a valid Python module' - ) - assert output_expected in output - # And exit immediately with an error exit code - assert worker_cli_process.returncode == 1 - - -class WrapClose: - def __init__(self, proc: subprocess.Popen): - self._proc = proc - self._stdout = proc.stdout - - def __enter__(self): - return self - - def __exit__(self, exc_type, value, traceback): - self.close() - - def __getattr__(self, name): - print(f"__getattr__({name})") - return getattr(self._stdout, name) - - def close(self): - self._stdout.close() - self._proc.wait() diff --git a/src/aiotaskq/tests/test_concurrency_manager.py b/src/aiotaskq/tests/test_concurrency_manager.py new file mode 100644 index 0000000..85bbe6b --- /dev/null +++ b/src/aiotaskq/tests/test_concurrency_manager.py @@ -0,0 +1,20 @@ +from aiotaskq.exceptions import ConcurrencyTypeNotSupported +from aiotaskq.concurrency_manager import ConcurrencyManager + + +def test_unsupported_concurrency_type(): + # Given an incorrect concurrency type + incorrect_concurrency_type = "some-incorrect-concurrency-type" + + # When getting the concurrency manager + error = None + try: + ConcurrencyManager._instance = None + ConcurrencyManager.get(concurrency_type=incorrect_concurrency_type, concurrency=4) + except ConcurrencyTypeNotSupported as e: + error = e + finally: + # Then a helpful error should be raised + assert ( + str(error) == 'Concurrency type "some-incorrect-concurrency-type" is not yet supported.' 
+ ) diff --git a/src/aiotaskq/tests/test_integration.py b/src/aiotaskq/tests/test_integration.py index a15be0a..2f46df0 100644 --- a/src/aiotaskq/tests/test_integration.py +++ b/src/aiotaskq/tests/test_integration.py @@ -1,32 +1,23 @@ -import asyncio -import subprocess -from typing import Any - import pytest from aiotaskq.main import Task +from aiotaskq.tests.conftest import WorkerFixture from aiotaskq.tests.apps import simple_app @pytest.mark.asyncio -async def test_sync_and_async_parity__simple_app(): +async def test_sync_and_async_parity__simple_app(worker: WorkerFixture): # Given a simple app running as a worker app = simple_app - bash_command = ["aiotaskq", "worker", app.__name__] - with subprocess.Popen(bash_command) as worker_cli_process: - # Once worker process is ready - await asyncio.sleep(0.5) - # Then there should be parity between sync and async call of the tasks - tests: list[tuple[Task, Any, Any]] = [ - (simple_app.add, tuple(), {"x": 41, "y": 1}), - (simple_app.power, (2,), {"b": 64}), - (simple_app.join, ([2021, 2, 20],), {}), - (simple_app.some_task, (21,), {}), - ] - try: - for task, args, kwargs in tests: - sync_result = task(*args, **kwargs) - async_result = await task.apply_async(*args, **kwargs) - assert async_result == sync_result, f"{async_result} != {sync_result}" - finally: - worker_cli_process.terminate() + await worker.start(app=app.__name__, concurrency=8) + # Then there should be parity between sync and async call of the tasks + tests: list[tuple[Task, tuple, dict]] = [ + (simple_app.add, tuple(), {"x": 41, "y": 1}), + (simple_app.power, (2,), {"b": 64}), + (simple_app.join, ([2021, 2, 20],), {}), + (simple_app.some_task, (21,), {}), + ] + for task, args, kwargs in tests: + sync_result = task(*args, **kwargs) + async_result = await task.apply_async(*args, **kwargs) + assert async_result == sync_result, f"{async_result} != {sync_result}" diff --git a/src/aiotaskq/tests/test_pubsub.py b/src/aiotaskq/tests/test_pubsub.py new file mode 100644 index 0000000..88ac140 --- /dev/null +++ b/src/aiotaskq/tests/test_pubsub.py @@ -0,0 +1,18 @@ +from aiotaskq.exceptions import UrlNotSupported +from aiotaskq.pubsub import PubSub + + +def test_invalid_url(): + # Given an unsupported pubsub url + unsupported_pubsub_url = "cache+memcached://127.0.0.1:11211/" + + # When getting a pubsub instance using the url + error = None + try: + PubSub._instance = None + PubSub.get(url=unsupported_pubsub_url, poll_interval_s=1.0) + except UrlNotSupported as e: + error = e + finally: + # Then a helpful error should be raised + assert str(error) == 'Url "cache+memcached://127.0.0.1:11211/" is currently not supported.' diff --git a/src/aiotaskq/tests/test_worker.py b/src/aiotaskq/tests/test_worker.py new file mode 100644 index 0000000..7523fd0 --- /dev/null +++ b/src/aiotaskq/tests/test_worker.py @@ -0,0 +1,163 @@ +import multiprocessing +import os +import signal +import subprocess +from typing import TYPE_CHECKING + +import pytest + +from aiotaskq.interfaces import ConcurrencyType +from aiotaskq.worker import run_worker_forever, validate_input + +if TYPE_CHECKING: # pragma: no cover + from aiotaskq.tests.conftest import WorkerFixture + + +@pytest.mark.asyncio +async def test_concurrency_starts_child_workers(worker: "WorkerFixture"): + """ + Assert that when --concurrency N option is provided, N child processes will be spawned.
+ """ + # Given that the worker cli is run with "--concurrency 4" option + concurrency = 4 + await worker.start(app="aiotaskq.tests.apps.simple_app", concurrency=concurrency) + + # Then the number of child worker processes spawned should be the same as requested + with os.popen(f"pgrep -P {worker.proc.pid} | wc -l") as child_process_counter: + child_process_count = int(child_process_counter.read()) + assert child_process_count == concurrency, f"{child_process_count} != {concurrency}" + + +@pytest.mark.asyncio +async def test_concurrency_starts_child_workers_with_default_concurrency( + worker: "WorkerFixture", +): + """ + Assert that when --concurrency is NOT provided, N child processes will be spawned, N=cpu cores. + """ + # Given that the machine has the following cpu count + cpu_count_on_machine = multiprocessing.cpu_count() + + # When the worker cli is run without "--concurrency" option + await worker.start(app="aiotaskq.tests.apps.simple_app") + + # Then the number of child worker processes spawned should be the same + # as cpu core on machine + with os.popen(f"pgrep -P {worker.proc.pid} | wc -l") as child_process_counter: + child_process_count = int(child_process_counter.read()) + assert child_process_count == cpu_count_on_machine + + +def test_incorrect_app(): + # Given that the worker is started with an incorrect app name + incorrect_app_name = "some.incorrect.app.name" + bash_command = ["aiotaskq", "worker", incorrect_app_name] + worker_cli_process = subprocess.Popen(args=bash_command, stdout=subprocess.PIPE) + with WrapClose(proc=worker_cli_process) as worker_cli_process_pipe: + # Then the worker process should print error message + output = str(worker_cli_process_pipe.read()) + output_expected = ( + "Error at argument `--app_import_path some.incorrect.app.name`: " + '"some.incorrect.app.name" is not a path to a valid Python module' + ) + assert output_expected in output + # And exit immediately with an error exit code + assert worker_cli_process.returncode == 1 + + +def test_validate_input(): + # Given an incorrect app path + incorrect_app_import_path = "some.incorrect.app.name" + + # When validating with `validate_input` + error_msg = validate_input(app_import_path=incorrect_app_import_path) + + # Then a descriptive error_message should be returned + assert error_msg == ( + "Error at argument `--app_import_path some.incorrect.app.name`:" + ' "some.incorrect.app.name" is not a path to a valid Python module' + ) + + +@pytest.mark.asyncio +async def test_run_worker__incorrect_app_name(worker: "WorkerFixture"): + # Given a worker being started with an incorrect app path + await worker.start( + app="some.incorrect.app.name", + concurrency=2, + concurrency_type=ConcurrencyType.MULTIPROCESSING, + poll_interval_s=1.0, + ) + + # Then the worker should exit immediately with an error exit code + assert worker.proc.exitcode == 1 + + +@pytest.mark.asyncio +async def test_handle_keyboard_interrupt(worker: "WorkerFixture"): + # Given a running worker with some child processes + concurrency = 4 + await worker.start("aiotaskq.tests.apps.simple_app", concurrency=concurrency) + bash_command = ( + "pidof $(which python) " # Get pids of all python processes + "| tr ' ' '\\n' " # Break single line into multiple lines for easier processing + f"| grep -v {os.getpid()} " # Filter out this current process + "| wc -l" # Count the number of pids + ) + with os.popen(bash_command) as process_counter: + process_count_before = int(process_counter.read()) + assert process_count_before == concurrency + 1 # 
workers (concurrent) + worker manager (1) + + # When SIGINT signal (Keyboard Interrupt aka Ctrl-C) is sent to the worker process + os.kill(worker.proc.pid, signal.SIGINT) + + # Then all child processes should be terminated + with os.popen(bash_command) as process_counter: + process_count_after = int(process_counter.read()) + assert process_count_after == process_count_before - (concurrency + 1) + + +@pytest.mark.asyncio +async def test_handle_termination_signal(worker: "WorkerFixture"): + # Given a running worker with some child processes + concurrency = 4 + await worker.start("aiotaskq.tests.apps.simple_app", concurrency=concurrency) + bash_command = ( + "pidof $(which python) " # Get pids of all python processes + "| tr ' ' '\\n' " # Break single line into multiple lines for easier processing + f"| grep -v {os.getpid()} " # Filter out this current process + "| wc -l" # Count the number of pids + ) + with os.popen(bash_command) as process_counter: + process_count_before = int(process_counter.read()) + assert process_count_before == concurrency + 1 # workers (concurrent) + worker manager (1) + + # When SIGTERM signal (Termination signal) is sent to the worker process + os.kill(worker.proc.pid, signal.SIGTERM) + + # Then all child processes should be terminated + with os.popen(bash_command) as process_counter: + process_count_after = int(process_counter.read()) + assert process_count_after == process_count_before - (concurrency + 1) + + +class WrapClose: + """Wrap a process and provide context manager support for closing all resources properly.""" + + def __init__(self, proc: subprocess.Popen): + self._proc = proc + self._stdout = proc.stdout + + def __enter__(self): + return self + + def __exit__(self, exc_type, value, traceback): + self.close() + + def __getattr__(self, name): + return getattr(self._stdout, name) + + def close(self): + self._stdout.close() + self._proc.wait() diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py index 3093b8e..9ea9521 100755 --- a/src/aiotaskq/worker.py +++ b/src/aiotaskq/worker.py @@ -1,180 +1,205 @@ """Module to define the main logic for the worker.""" +from abc import ABC, abstractmethod import asyncio +from functools import cached_property import importlib import json import logging import multiprocessing import os +import signal import sys -import types import typing as t +import types -import aioredis - -from aiotaskq.constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL - +from .concurrency_manager import ConcurrencyManager +from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL +from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub +from .pubsub import PubSub logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +class BaseWorker(ABC): + app: types.ModuleType + pubsub: IPubSub + concurrency_manager: IConcurrencyManager + + def __init__(self, app_import_path: str): + self.app = importlib.import_module(app_import_path) + + def run_forever(self) -> None: + """Run the worker forever in a loop, after running some preparation logic (in _pre_run).""" + + async def _start(): + await self._pre_run() + await self._main_loop() + + asyncio.run(_start()) + + @abstractmethod + async def _pre_run(self): + """Define any logic to run before running the _main_loop.""" + + @abstractmethod + async def _main_loop(self): + """Define the logic for the main loop.""" + + @cached_property + def _logger(self): + return logging.getLogger(f"[{self._pid}] [{self.__class__.__qualname__}]") + + 
@cached_property + def _pid(self) -> int: + return os.getpid() + + @staticmethod + def _get_child_worker_tasks_channel(pid: int) -> str: + return f"{TASKS_CHANNEL}:{pid}" + + class Defaults: """Store default constants for the aiotaskq.worker module.""" - @classmethod - @property - def poll_interval_s(cls) -> float: - """Return the time in seconds to poll for next task.""" - return 0.1 - @classmethod @property def concurrency(cls) -> int: """Return the number of worker process to spawn.""" return multiprocessing.cpu_count() + @classmethod + @property + def concurrency_type(cls) -> str: + return ConcurrencyType.MULTIPROCESSING.value -async def worker( - app_import_path: str, - concurrency: int, - poll_interval_s: t.Optional[float] = Defaults.poll_interval_s, -): - """Main loop for worker to poll for next task and execute them.""" + @classmethod + @property + def poll_interval_s(cls) -> float: + """Return the time in seconds to poll for next task.""" + return 0.01 + + +class WorkerManager(BaseWorker): + def __init__( + self, + app_import_path: str, + concurrency: int, + concurrency_type: ConcurrencyType, + poll_interval_s: float, + ) -> None: + self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s) + self.concurrency_manager: IConcurrencyManager = ConcurrencyManager.get( + concurrency_type=concurrency_type, + concurrency=concurrency, + ) + self._poll_interval_s = poll_interval_s + super().__init__(app_import_path=app_import_path) + + async def _pre_run(self): + self._logger.info("Starting %s background workers", self.concurrency_manager.concurrency) + self._start_grunt_workers() + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGTERM, self._sigterm_handler) + loop.add_signal_handler(signal.SIGINT, self._sigint_handler) + + def _sigterm_handler(self): + # pylint: disable=no-member + self._logger.debug("Handling signal %s (%s)", signal.SIGTERM.value, signal.SIGTERM.name) + self._handle_murder_signals() + + def _sigint_handler(self): + # pylint: disable=no-member + self._logger.debug("Handling signal %s (%s)", signal.SIGINT.value, signal.SIGINT.name) + self._handle_murder_signals() + + def _handle_murder_signals(self): + """Terminate (send TERM to) the child processes upon receiving murder signals (TERM, INT).""" + for task in asyncio.tasks.all_tasks(): + task.cancel() + self.concurrency_manager.terminate() + + async def _main_loop(self): + self._logger.info("Started main loop") + + async with self.pubsub as pubsub: + counter = -1 + grunt_worker_pids: list[int] = list(self.concurrency_manager.processes.keys()) + await pubsub.subscribe(TASKS_CHANNEL) + while True: + self._logger.debug("Polling for a new task until it's available") + message = await pubsub.poll() - err_msg: str = _validate_input(app_import_path=app_import_path) - if err_msg: - print(err_msg) - sys.exit(1) + # A new task is now available + # Pass the task to one of the grunt workers + counter = (counter + 1) % len(self.concurrency_manager.processes) + selected_grunt_worker_pid = grunt_worker_pids[counter] + channel: str = self._get_child_worker_tasks_channel(pid=selected_grunt_worker_pid) + self._logger.debug( + "[%s] Passing task to %sth child worker [message=%s, channel=%s]", + *(self._pid, counter, message, channel), + ) + await pubsub.publish(channel=channel, message=message["data"]) - pid: int = os.getpid() - logger.info( - "[pid=%s] aiotaskq worker \n" - "\tversion: %s\n" - "\tpoll interval (seconds): %s\n" - "\tredis url: %s\n" - "\tconcurrency: %s\n", - *(pid, "1.0.0", poll_interval_s,
REDIS_URL, concurrency), - ) - - # Ensure child worker processes log to parent's stderr - multiprocessing.log_to_stderr(logging.DEBUG) - - # Start child worker processes in background - child_worker_processes: list["multiprocessing.Process"] = [ - multiprocessing.Process(target=_worker, args=(app_import_path, poll_interval_s)) - for _ in range(concurrency) - ] - for proc in child_worker_processes: - proc.start() - child_worker_pids = [c.pid for c in child_worker_processes] - - # Main worker accepts new task and pass it on to one of child workers - logger.info( - "[%s] Forked %s child worker processes: [pids=%s]", - *(pid, len(child_worker_processes), child_worker_pids), - ) - redis_client: aioredis.Redis = aioredis.from_url(REDIS_URL) - async with redis_client.pubsub() as pubsub: - await pubsub.subscribe(TASKS_CHANNEL) - counter = 0 - while True: - # Poll for a new task until it's available - message: t.Optional[str] = None - while message is None: - message = await pubsub.get_message(ignore_subscribe_messages=True) - await asyncio.sleep(poll_interval_s) - - # A new task is now available - # Pass the task to one of the workers worker - counter = (counter + 1) % len(child_worker_processes) - selected_child_worker = child_worker_processes[counter] - channel: str = _get_child_worker_tasks_channel(pid=selected_child_worker.pid) - logger.debug( - "[%s] Passing task to %sth child worker [message=%s, channel=%s]", - *(pid, counter, message, channel), + def _start_grunt_workers(self): + def _run_grunt_worker_forever(): + grunt_worker = GruntWorker( + app_import_path=self.app.__name__, + poll_interval_s=self._poll_interval_s, ) - await redis_client.publish(channel, message=message["data"]) + grunt_worker.run_forever() + self.concurrency_manager.start(func=_run_grunt_worker_forever) -def _worker( - app_import_path: str, - poll_interval_s: t.Optional[float] = Defaults.poll_interval_s, -): - pid: int = os.getpid() - channel: str = _get_child_worker_tasks_channel(pid=pid) - app: types.ModuleType = importlib.import_module(app_import_path) - logger.info("Child worker process [pid=%s]", pid) - - async def _main_loop(): - logger.info("[%s] Main loop", pid) - redis_client: aioredis.Redis = aioredis.from_url(REDIS_URL) - async with redis_client.pubsub() as pubsub: - await pubsub.subscribe(channel) +class GruntWorker(BaseWorker): + def __init__(self, app_import_path: str, poll_interval_s: float): + self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s) + super().__init__(app_import_path=app_import_path) + + async def _pre_run(self): + pass + + async def _main_loop(self): + self._logger.debug("Started main loop") + channel: str = self._get_child_worker_tasks_channel(pid=self._pid) + + async with self.pubsub as pubsub: + await pubsub.subscribe(channel=channel) while True: - # Poll for a new task until it's available - logger.debug( - "[%s] Waiting for new tasks from main worker [channel=%s]", - *(pid, channel), + self._logger.debug( + "[%s] Polling for a new task from the manager until it's available [channel=%s]", + *(self._pid, channel), ) - message = None - while message is None: - message = await pubsub.get_message(ignore_subscribe_messages=True) - await asyncio.sleep(poll_interval_s) + message = await pubsub.poll() # A new task is now available - logger.debug( + self._logger.debug( "[%s] Received task from main worker [message=%s, channel=%s]", - *(pid, message, channel), + *(self._pid, message, channel), ) task_info = json.loads(message["data"]) task_args = task_info["args"]
task_kwargs = task_info["kwargs"] task_id: str = task_info["task_id"] task_func_name: str = task_id.split(":")[0].split(".")[-1] - task_func = getattr(app, task_func_name) + task_func = getattr(self.app, task_func_name) # Execute the task - logger.debug( + self._logger.debug( "[%s] Executing task %s(*%s, **%s)", - *(pid, task_id, task_args, task_kwargs), + *(self._pid, task_id, task_args, task_kwargs), ) task_result = task_func(*task_args, **task_kwargs) # Publish the task return value task_result = json.dumps(task_result) - await redis_client.publish( - RESULTS_CHANNEL_TEMPLATE.format(task_id=task_id), - message=task_result, - ) - - asyncio.run(main=_main_loop()) - - -async def _wait_for_child_workers_ready( - publisher: aioredis.Redis, - child_worker_pids: list[int], - poll_interval_s: int, -) -> None: - while True: - # Keep check until all child workers are ready - ready_statuses_coro: list[t.Coroutine] = [ - publisher.pubsub_numsub(_get_child_worker_tasks_channel(pid=pid)) - for pid in child_worker_pids - ] - ready_statuses: list[list[tuple[bytes, int]]] = await asyncio.gather(*ready_statuses_coro) - if all(is_ready for (_, is_ready), *_ in ready_statuses): - break - # Continue waiting and checking - await asyncio.sleep(poll_interval_s) - + result_channel = RESULTS_CHANNEL_TEMPLATE.format(task_id=task_id) + await pubsub.publish(channel=result_channel, message=task_result) -def _get_child_worker_tasks_channel(pid: int) -> str: - return f"{TASKS_CHANNEL}:{pid}" - -def _validate_input(app_import_path: str) -> t.Optional[str]: +def validate_input(app_import_path: str) -> t.Optional[str]: try: importlib.import_module(app_import_path) except ModuleNotFoundError: @@ -184,3 +209,26 @@ def _validate_input(app_import_path: str) -> t.Optional[str]: ) return None + + +def run_worker_forever( + app_import_path: str, + concurrency: int, + concurrency_type: ConcurrencyType, + poll_interval_s: float, +) -> None: + err_msg: t.Optional[str] = validate_input(app_import_path=app_import_path) + if err_msg: + print(err_msg) + sys.exit(1) + + try: + worker_manager = WorkerManager( + app_import_path=app_import_path, + concurrency=concurrency, + concurrency_type=concurrency_type, + poll_interval_s=poll_interval_s, + ) + worker_manager.run_forever() + except asyncio.CancelledError: + pass diff --git a/test.sh b/test.sh index 01ea116..a6c6ac2 100755 --- a/test.sh +++ b/test.sh @@ -3,9 +3,13 @@ pip install pytest pip install pytest-asyncio pip install coverage +coverage erase + if [ -z $1 ]; then coverage run -m pytest -v -s else coverage run -m pytest -v -s -k $1 fi + +coverage combine
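Note for reviewers: a usage sketch of the new entry point from Python rather than through the CLI, based on the signatures added in this patch (assumes a Redis server is reachable at the configured REDIS_URL, and note that the call blocks forever):

    from aiotaskq.interfaces import ConcurrencyType
    from aiotaskq.worker import run_worker_forever

    # Start a worker manager with 4 grunt workers, equivalent to:
    #   aiotaskq worker aiotaskq.tests.apps.simple_app --concurrency 4
    run_worker_forever(
        app_import_path="aiotaskq.tests.apps.simple_app",
        concurrency=4,
        concurrency_type=ConcurrencyType.MULTIPROCESSING,
        poll_interval_s=0.01,
    )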