Skip to content

Commit

Permalink
fix: incorrect concurrent usage of connection and transaction (#546)
Browse files Browse the repository at this point in the history
* fix: incorrect concurrent usage of connection and transaction

* refactor: rename contextvar class attributes, add some explaination comments

* fix: contextvar.get takes no keyword arguments

* test: add concurrent task tests

* feat: use ContextVar[dict] to track connections and transactions per task

* test: check multiple databases in the same task use independant connections

* chore: changes for linting and typechecking

* chore: use typing.Tuple for lower python version compatibility

* docs: update comment on _connection_contextmap

* Update `Connection` and `Transaction` to be robust to concurrent use

* chore: remove optional annotation on asyncio.Task

* test: add new tests for upcoming contextvar inheritance/isolation and weakref cleanup

* feat: reimplement concurrency system with contextvar and weakmap

* chore: apply corrections from linters

* fix: quote WeakKeyDictionary typing for python<=3.7

* docs: add examples for async transaction context and nested transactions

* fix: remove connection inheritance, add more tests, update docs

Connections are once again stored as state on the Database instance,
keyed by the current asyncio.Task. Each task acquires it's own
connection, and a WeakKeyDictionary allows the connection to be
discarded if the owning task is garbage collected. TransactionBackends
are still stored as contextvars, and a connection must be explicitly
provided to descendant tasks if active transaction state is to be
inherited.

---------

Co-authored-by: Zanie <contact@zanie.dev>
  • Loading branch information
zevisert and zanieb authored Jul 25, 2023
1 parent c095428 commit 25fa295
Show file tree
Hide file tree
Showing 3 changed files with 521 additions and 35 deletions.
92 changes: 78 additions & 14 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import logging
import typing
import weakref
from contextvars import ContextVar
from types import TracebackType
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
Expand All @@ -11,7 +12,7 @@
from sqlalchemy.sql import ClauseElement

from databases.importer import import_from_string
from databases.interfaces import DatabaseBackend, Record
from databases.interfaces import DatabaseBackend, Record, TransactionBackend

try: # pragma: no cover
import click
Expand All @@ -35,6 +36,11 @@
logger = logging.getLogger("databases")


_ACTIVE_TRANSACTIONS: ContextVar[
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
] = ContextVar("databases:active_transactions", default=None)


class Database:
SUPPORTED_BACKENDS = {
"postgresql": "databases.backends.postgres:PostgresBackend",
Expand All @@ -45,6 +51,8 @@ class Database:
"sqlite": "databases.backends.sqlite:SQLiteBackend",
}

_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"

def __init__(
self,
url: typing.Union[str, "DatabaseURL"],
Expand All @@ -55,6 +63,7 @@ def __init__(
self.url = DatabaseURL(url)
self.options = options
self.is_connected = False
self._connection_map = weakref.WeakKeyDictionary()

self._force_rollback = force_rollback

Expand All @@ -63,14 +72,35 @@ def __init__(
assert issubclass(backend_cls, DatabaseBackend)
self._backend = backend_cls(self.url, **self.options)

# Connections are stored as task-local state.
self._connection_context: ContextVar = ContextVar("connection_context")

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
self._global_connection: typing.Optional[Connection] = None
self._global_transaction: typing.Optional[Transaction] = None

@property
def _current_task(self) -> asyncio.Task:
task = asyncio.current_task()
if not task:
raise RuntimeError("No currently active asyncio.Task found")
return task

@property
def _connection(self) -> typing.Optional["Connection"]:
return self._connection_map.get(self._current_task)

@_connection.setter
def _connection(
self, connection: typing.Optional["Connection"]
) -> typing.Optional["Connection"]:
task = self._current_task

if connection is None:
self._connection_map.pop(task, None)
else:
self._connection_map[task] = connection

return self._connection

async def connect(self) -> None:
"""
Establish the connection pool.
Expand All @@ -89,7 +119,7 @@ async def connect(self) -> None:
assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self._backend)
self._global_connection = Connection(self, self._backend)
self._global_transaction = self._global_connection.transaction(
force_rollback=True
)
Expand All @@ -113,7 +143,7 @@ async def disconnect(self) -> None:
self._global_transaction = None
self._global_connection = None
else:
self._connection_context = ContextVar("connection_context")
self._connection = None

await self._backend.disconnect()
logger.info(
Expand Down Expand Up @@ -187,12 +217,10 @@ def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

try:
return self._connection_context.get()
except LookupError:
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection
if not self._connection:
self._connection = Connection(self, self._backend)

return self._connection

def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
Expand All @@ -215,7 +243,8 @@ def _get_backend(self) -> str:


class Connection:
def __init__(self, backend: DatabaseBackend) -> None:
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
self._database = database
self._backend = backend

self._connection_lock = asyncio.Lock()
Expand Down Expand Up @@ -249,6 +278,7 @@ async def __aexit__(
self._connection_counter -= 1
if self._connection_counter == 0:
await self._connection.release()
self._database._connection = None

async def fetch_all(
self,
Expand Down Expand Up @@ -345,6 +375,37 @@ def __init__(
self._force_rollback = force_rollback
self._extra_options = kwargs

@property
def _connection(self) -> "Connection":
# Returns the same connection if called multiple times
return self._connection_callable()

@property
def _transaction(self) -> typing.Optional["TransactionBackend"]:
transactions = _ACTIVE_TRANSACTIONS.get()
if transactions is None:
return None

return transactions.get(self, None)

@_transaction.setter
def _transaction(
self, transaction: typing.Optional["TransactionBackend"]
) -> typing.Optional["TransactionBackend"]:
transactions = _ACTIVE_TRANSACTIONS.get()
if transactions is None:
transactions = weakref.WeakKeyDictionary()
else:
transactions = transactions.copy()

if transaction is None:
transactions.pop(self, None)
else:
transactions[self] = transaction

_ACTIVE_TRANSACTIONS.set(transactions)
return transactions.get(self, None)

async def __aenter__(self) -> "Transaction":
"""
Called when entering `async with database.transaction()`
Expand Down Expand Up @@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return wrapper # type: ignore

async def start(self) -> "Transaction":
self._connection = self._connection_callable()
self._transaction = self._connection._connection.transaction()

async with self._connection._transaction_lock:
Expand All @@ -401,15 +461,19 @@ async def commit(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
assert self._transaction is not None
await self._transaction.commit()
await self._connection.__aexit__()
self._transaction = None

async def rollback(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
assert self._transaction is not None
await self._transaction.rollback()
await self._connection.__aexit__()
self._transaction = None


class _EmptyNetloc(str):
Expand Down
54 changes: 50 additions & 4 deletions docs/connections_and_transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.

## Connecting and disconnecting

You can control the database connect/disconnect, by using it as a async context manager.
You can control the database connection pool with an async context manager:

```python
async with Database(DATABASE_URL) as database:
...
```

Or by using explicit connection and disconnection:
Or by using the explicit `.connect()` and `.disconnect()` methods:

```python
database = Database(DATABASE_URL)
Expand All @@ -23,6 +23,8 @@ await database.connect()
await database.disconnect()
```

Connections within this connection pool are acquired for each new `asyncio.Task`.

If you're integrating against a web framework, then you'll probably want
to hook into framework startup or shutdown events. For example, with
[Starlette][starlette] you would use the following:
Expand Down Expand Up @@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool:
async with database.transaction():
...
```

It can also be acquired from a specific database connection:

```python
Expand Down Expand Up @@ -95,8 +98,51 @@ async def create_users(request):
...
```

Transaction blocks are managed as task-local state. Nested transactions
are fully supported, and are implemented using database savepoints.
Transaction state is tied to the connection used in the currently executing asynchronous task.
If you would like to influence an active transaction from another task, the connection must be
shared. This state is _inherited_ by tasks that are share the same connection:

```python
async def add_excitement(connnection: databases.core.Connection, id: int):
await connection.execute(
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
{"id": id}
)


async with Database(database_url) as database:
async with database.transaction():
# This note won't exist until the transaction closes...
await database.execute(
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
)
# ...but child tasks can use this connection now!
await asyncio.create_task(add_excitement(database.connection(), id=1))

await database.fetch_val("SELECT text FROM notes WHERE id=1")
# ^ returns: "databases is cool!!!"
```

Nested transactions are fully supported, and are implemented using database savepoints:

```python
async with databases.Database(database_url) as db:
async with db.transaction() as outer:
# Do something in the outer transaction
...

# Suppress to prevent influence on the outer transaction
with contextlib.suppress(ValueError):
async with db.transaction():
# Do something in the inner transaction
...

raise ValueError('Abort the inner transaction')

# Observe the results of the outer transaction,
# without effects from the inner transaction.
await db.fetch_all('SELECT * FROM ...')
```

Transaction isolation-level can be specified if the driver backend supports that:

Expand Down
Loading

0 comments on commit 25fa295

Please sign in to comment.