Skip to content

Commit

Permalink
add shield() in aexit
Browse files Browse the repository at this point in the history
Added ``asyncio.shield()`` to the connection and session release process
specifically within the ``__aexit__()`` context manager exit, when using
:class:`.AsyncConnection` or :class:`.AsyncSession` as a context manager
that releases the object when the context manager is complete. This appears
to help with task cancellation when using alternate concurrency libraries
such as ``anyio``, ``uvloop`` that otherwise don't provide an async context
for the connection pool to release the connection properly during task
cancellation.

Fixes: #8145
Change-Id: I0b1ea9c3a22a18619341cbb8591225fcd339042c
  • Loading branch information
CaselIT authored and zzzeek committed Jul 18, 2022
1 parent c3102b8 commit 1acaf0b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
14 changes: 14 additions & 0 deletions doc/build/changelog/unreleased_14/8145.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. change::
:tags: bug, asyncio
:tickets: 8145

Added ``asyncio.shield()`` to the connection and session release process
specifically within the ``__aexit__()`` context manager exit, when using
:class:`.AsyncConnection` or :class:`.AsyncSession` as a context manager
that releases the object when the context manager is complete. This appears
to help with task cancellation when using alternate concurrency libraries
such as ``anyio``, ``uvloop`` that otherwise don't provide an async context
for the connection pool to release the connection properly during task
cancellation.


12 changes: 8 additions & 4 deletions lib/sqlalchemy/ext/asyncio/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations

import asyncio
from typing import Any
from typing import Dict
from typing import Generator
Expand Down Expand Up @@ -698,7 +699,7 @@ def __await__(self) -> Generator[Any, None, AsyncConnection]:
return self.start().__await__()

async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
await asyncio.shield(self.close())

# START PROXY METHODS AsyncConnection

Expand Down Expand Up @@ -855,8 +856,11 @@ async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
async def __aexit__(
self, type_: Any, value: Any, traceback: Any
) -> None:
await self.transaction.__aexit__(type_, value, traceback)
await self.conn.close()
async def go() -> None:
await self.transaction.__aexit__(type_, value, traceback)
await self.conn.close()

await asyncio.shield(go())

def __init__(self, sync_engine: Engine):
if not sync_engine.dialect.is_async:
Expand Down Expand Up @@ -956,7 +960,7 @@ async def dispose(self, close: bool = True) -> None:
"""

return await greenlet_spawn(self.sync_engine.dispose, close=close)
await greenlet_spawn(self.sync_engine.dispose, close=close)

# START PROXY METHODS AsyncEngine

Expand Down
12 changes: 8 additions & 4 deletions lib/sqlalchemy/ext/asyncio/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations

import asyncio
from typing import Any
from typing import Dict
from typing import Generic
Expand Down Expand Up @@ -837,7 +838,7 @@ async def close(self) -> None:
:meth:`_asyncio.AsyncSession.close`
"""
return await greenlet_spawn(self.sync_session.close)
await greenlet_spawn(self.sync_session.close)

async def invalidate(self) -> None:
"""Close this Session, using connection invalidation.
Expand All @@ -855,7 +856,7 @@ async def __aenter__(self: _AS) -> _AS:
return self

async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
await asyncio.shield(self.close())

def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]:
return _AsyncSessionContextManager(self)
Expand Down Expand Up @@ -1516,8 +1517,11 @@ async def __aenter__(self) -> _AS:
return self.async_session

async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.trans.__aexit__(type_, value, traceback)
await self.async_session.__aexit__(type_, value, traceback)
async def go() -> None:
await self.trans.__aexit__(type_, value, traceback)
await self.async_session.__aexit__(type_, value, traceback)

await asyncio.shield(go())


class AsyncSessionTransaction(
Expand Down

0 comments on commit 1acaf0b

Please sign in to comment.