diff --git a/aiomonitor/monitor.py b/aiomonitor/monitor.py index 906a20b3..d930d086 100644 --- a/aiomonitor/monitor.py +++ b/aiomonitor/monitor.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import functools import logging import os @@ -318,6 +319,9 @@ def _create_task( self, loop: asyncio.AbstractEventLoop, coro: Coroutine[Any, Any, T_co] | Generator[Any, None, T_co], + *, + name: str | None = None, + context: contextvars.Context | None = None, ) -> asyncio.Future[T_co]: assert loop is self._monitored_loop try: @@ -331,6 +335,8 @@ def _create_task( cancellation_chain_queue=self._cancellation_chain_queue.sync_q, persistent=persistent, loop=self._monitored_loop, + name=name, # since Python 3.8 + context=context, # since Python 3.11 ) task._orig_coro = cast(Coroutine[Any, Any, T_co], coro) self._created_tracebacks[task] = _extract_stack_from_frame(sys._getframe())[ diff --git a/tests/test_monitor.py b/tests/test_monitor.py index c97ebb7f..e8523f25 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import telnetlib import threading import time @@ -142,6 +143,32 @@ def test_basic_monitor(monitor, tn_client, loop): assert "No task 123" in resp +myvar = contextvars.ContextVar("myvar", default=42) + + +def test_monitor_task_factory(): + ctx = contextvars.Context() + # This context is bound at the outermost scope, + # and inside it the initial value of myvar is kept intact. + + async def do(): + await asyncio.sleep(0) + assert myvar.get() == 42 # we are referring the outer context + myself = asyncio.current_task() + assert myself is not None + assert myself.get_name() == "mytask" + + async def main(): + myvar.set(99) # override in the current task's context + loop = asyncio.get_running_loop() + with Monitor(loop, console_enabled=False, hook_task_factory=True): + t = asyncio.create_task(do(), name="mytask", context=ctx) + await t + assert myvar.get() == 99 + + asyncio.run(main()) + + def test_cancel_where_tasks(monitor, tn_client, loop): tn = tn_client