From 1a4f2b577650753b4e3af4310daf4858b6f9c896 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 4 Oct 2024 09:38:55 +0100 Subject: [PATCH] gh-124958: fix asyncio.TaskGroup and _PyFuture refcycles --- Lib/asyncio/futures.py | 6 +- Lib/asyncio/taskgroups.py | 161 ++++++++++++----------- Lib/test/test_asyncio/test_futures.py | 22 ++++ Lib/test/test_asyncio/test_taskgroups.py | 95 ++++++++++++- 4 files changed, 203 insertions(+), 81 deletions(-) diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py index 5f6fa2348726cf..c95fce035cd548 100644 --- a/Lib/asyncio/futures.py +++ b/Lib/asyncio/futures.py @@ -190,8 +190,7 @@ def result(self): the future is done and has an exception set, this exception is raised. """ if self._state == _CANCELLED: - exc = self._make_cancelled_error() - raise exc + raise self._make_cancelled_error() if self._state != _FINISHED: raise exceptions.InvalidStateError('Result is not ready.') self.__log_traceback = False @@ -208,8 +207,7 @@ def exception(self): InvalidStateError. """ if self._state == _CANCELLED: - exc = self._make_cancelled_error() - raise exc + raise self._make_cancelled_error() if self._state != _FINISHED: raise exceptions.InvalidStateError('Exception is not set.') self.__log_traceback = False diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index f2ee9648c43876..6061534dc7fdd2 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -66,94 +66,103 @@ async def __aenter__(self): return self async def __aexit__(self, et, exc, tb): - self._exiting = True + try: + self._exiting = True - if (exc is not None and - self._is_base_error(exc) and - self._base_error is None): - self._base_error = exc + if (exc is not None and + self._is_base_error(exc) and + self._base_error is None): + self._base_error = exc - if et is not None and issubclass(et, exceptions.CancelledError): - propagate_cancellation_error = exc - else: - propagate_cancellation_error = None + if et is not None and issubclass(et, exceptions.CancelledError): + propagate_cancellation_error = exc + else: + propagate_cancellation_error = None - if et is not None: - if not self._aborting: - # Our parent task is being cancelled: - # - # async with TaskGroup() as g: - # g.create_task(...) - # await ... # <- CancelledError - # - # or there's an exception in "async with": - # - # async with TaskGroup() as g: - # g.create_task(...) - # 1 / 0 - # - self._abort() - - # We use while-loop here because "self._on_completed_fut" - # can be cancelled multiple times if our parent task - # is being cancelled repeatedly (or even once, when - # our own cancellation is already in progress) - while self._tasks: - if self._on_completed_fut is None: - self._on_completed_fut = self._loop.create_future() - - try: - await self._on_completed_fut - except exceptions.CancelledError as ex: + if et is not None: if not self._aborting: # Our parent task is being cancelled: # - # async def wrapper(): - # async with TaskGroup() as g: - # g.create_task(foo) + # async with TaskGroup() as g: + # g.create_task(...) + # await ... # <- CancelledError + # + # or there's an exception in "async with": + # + # async with TaskGroup() as g: + # g.create_task(...) + # 1 / 0 # - # "wrapper" is being cancelled while "foo" is - # still running. - propagate_cancellation_error = ex self._abort() - self._on_completed_fut = None - - assert not self._tasks - - if self._base_error is not None: - raise self._base_error - - if self._parent_cancel_requested: - # If this flag is set we *must* call uncancel(). - if self._parent_task.uncancel() == 0: - # If there are no pending cancellations left, - # don't propagate CancelledError. - propagate_cancellation_error = None - - # Propagate CancelledError if there is one, except if there - # are other errors -- those have priority. - if propagate_cancellation_error is not None and not self._errors: - raise propagate_cancellation_error - - if et is not None and not issubclass(et, exceptions.CancelledError): - self._errors.append(exc) - - if self._errors: - # If the parent task is being cancelled from the outside - # of the taskgroup, un-cancel and re-cancel the parent task, - # which will keep the cancel count stable. - if self._parent_task.cancelling(): - self._parent_task.uncancel() - self._parent_task.cancel() + # We use while-loop here because "self._on_completed_fut" + # can be cancelled multiple times if our parent task + # is being cancelled repeatedly (or even once, when + # our own cancellation is already in progress) + while self._tasks: + if self._on_completed_fut is None: + self._on_completed_fut = self._loop.create_future() + + try: + await self._on_completed_fut + except exceptions.CancelledError as ex: + if not self._aborting: + # Our parent task is being cancelled: + # + # async def wrapper(): + # async with TaskGroup() as g: + # g.create_task(foo) + # + # "wrapper" is being cancelled while "foo" is + # still running. + propagate_cancellation_error = ex + self._abort() + + self._on_completed_fut = None + + assert not self._tasks + + if self._base_error is not None: + raise self._base_error + + if self._parent_cancel_requested: + # If this flag is set we *must* call uncancel(). + if self._parent_task.uncancel() == 0: + # If there are no pending cancellations left, + # don't propagate CancelledError. + propagate_cancellation_error = None + + # Propagate CancelledError if there is one, except if there + # are other errors -- those have priority. + if propagate_cancellation_error is not None and not self._errors: + raise propagate_cancellation_error + + if et is not None and not issubclass(et, exceptions.CancelledError): + self._errors.append(exc) + + if self._errors: + # If the parent task is being cancelled from the outside + # of the taskgroup, un-cancel and re-cancel the parent task, + # which will keep the cancel count stable. + if self._parent_task.cancelling(): + self._parent_task.uncancel() + self._parent_task.cancel() + raise BaseExceptionGroup( + 'unhandled errors in a TaskGroup', + self._errors, + ) from None + finally: # Exceptions are heavy objects that can have object # cycles (bad for GC); let's not keep a reference to # a bunch of them. - try: - me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors) - raise me from None - finally: - self._errors = None + propagate_cancellation_error = None + self._parent_task = None + self._errors = None + self._base_error = None + et = None + exc = None + tb = None + def create_task(self, coro, *, name=None, context=None): """Create a new task in this group and return it. diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py index 458b70451a306a..c566b28adb2408 100644 --- a/Lib/test/test_asyncio/test_futures.py +++ b/Lib/test/test_asyncio/test_futures.py @@ -659,6 +659,28 @@ def __del__(self): fut = self._new_future(loop=self.loop) fut.set_result(Evil()) + def test_future_cancelled_result_refcycles(self): + f = self._new_future(loop=self.loop) + f.cancel() + exc = None + try: + f.result() + except asyncio.CancelledError as e: + exc = e + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + + def test_future_cancelled_exception_refcycles(self): + f = self._new_future(loop=self.loop) + f.cancel() + exc = None + try: + f.exception() + except asyncio.CancelledError as e: + exc = e + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + @unittest.skipUnless(hasattr(futures, '_CFuture'), 'requires the C _asyncio module') diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 4852536defc93d..b2cb456791d8d6 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -1,7 +1,7 @@ # Adapted with permission from the EdgeDB project; # license: PSFL. - +import gc import asyncio import contextvars import contextlib @@ -11,6 +11,10 @@ from test.test_asyncio.utils import await_without_task +if False: + asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = asyncio.tasks._PyTask + asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = asyncio.futures._PyFuture + # To prevent a warning "test altered the execution environment" def tearDownModule(): @@ -899,6 +903,95 @@ async def outer(): await outer() + async def test_exception_refcycles_direct(self): + """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + try: + async with tg: + raise _Done + except ExceptionGroup as e: + exc = e + + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + + + async def test_exception_refcycles_errors(self): + """Test that TaskGroup deletes self._errors, and __aexit__ args""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + try: + async with tg: + raise _Done + except* _Done as excs: + exc = excs.exceptions[0] + + self.assertIsInstance(exc, _Done) + self.assertListEqual(gc.get_referrers(exc), []) + + + async def test_exception_refcycles_parent_task(self): + """Test that TaskGroup deletes self._parent_task""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + async def coro_fn(): + async with tg: + raise _Done + + try: + async with asyncio.TaskGroup() as tg2: + tg2.create_task(coro_fn()) + except* _Done as excs: + exc = excs.exceptions[0].exceptions[0] + + self.assertIsInstance(exc, _Done) + self.assertListEqual(gc.get_referrers(exc), []) + + async def test_exception_refcycles_propagate_cancellation_error(self): + """Test that TaskGroup deletes propagate_cancellation_error""" + tg = asyncio.TaskGroup() + exc = None + + try: + async with asyncio.timeout(-1): + async with tg: + await asyncio.sleep(0) + except TimeoutError as e: + exc = e.__cause__ + + self.assertIsInstance(exc, asyncio.CancelledError) + self.assertListEqual(gc.get_referrers(exc), []) + + async def test_exception_refcycles_base_error(self): + """Test that TaskGroup deletes self._base_error""" + class MyKeyboardInterrupt(KeyboardInterrupt): + pass + + tg = asyncio.TaskGroup() + exc = None + + try: + async with tg: + raise MyKeyboardInterrupt + except MyKeyboardInterrupt as e: + exc = e + + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + if __name__ == "__main__": unittest.main()