
Add ability to wait for locks and add locks to purge history / room deletion #15791

Merged · 11 commits · Jul 31, 2023
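For orientation, a minimal hedged sketch of how code elsewhere in Synapse might wait on one of the new locks before touching a room. The lock name and attribute names below are illustrative assumptions, not taken from the diff that follows.

# Illustrative sketch only (not part of this PR's diff). The lock name and
# the `_worker_lock_handler` attribute are assumptions for the example.
DELETE_ROOM_LOCK_NAME = "delete_room_lock"

async def _do_work_in_room(self, room_id: str) -> None:
    # Blocks until no purge/deletion worker holds the write lock for this room.
    async with self._worker_lock_handler.acquire_read_write_lock(
        DELETE_ROOM_LOCK_NAME, room_id, write=False
    ):
        await self._actually_do_work(room_id)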
111 changes: 108 additions & 3 deletions synapse/handlers/worker_lock.py
@@ -14,7 +14,16 @@

import random
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
AsyncContextManager,
Collection,
Dict,
Optional,
Tuple,
Type,
Union,
)
from weakref import WeakSet

import attr
@@ -47,7 +56,9 @@ def __init__(self, hs: "HomeServer") -> None:

# Map from lock name/key to set of `WaitingLock` that are active for
# that lock.
self._locks: Dict[Tuple[str, str], WeakSet[WaitingLock]] = {}
self._locks: Dict[
Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
] = {}

self._clock.looping_call(self._cleanup_locks, 30_000)

@@ -103,6 +114,29 @@ def acquire_read_write_lock(

return lock

def acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
*,
write: bool,
) -> "WaitingMultiLock":
"""Acquires multi read/write locks at once, returns a context manager
that will block until all the locks are acquired.
"""

lock = WaitingMultiLock(
lock_names=lock_names,
write=write,
reactor=self._reactor,
store=self._store,
handler=self,
)

for lock_name, lock_key in lock_names:
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

return lock
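A hedged usage sketch for the new method (the lock names/keys and the handler attribute are placeholders, not taken from the diff):

# Illustrative only: acquire two locks together; both are released when the
# `async with` block exits.
async def _purge_room(self, room_id: str) -> None:
    async with self._worker_lock_handler.acquire_multi_read_write_lock(
        [("purge_history", room_id), ("delete_room", room_id)],
        write=True,
    ):
        # Both locks are held for the duration of this block.
        await self._delete_everything(room_id)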

def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
"""Notify that a lock has been released.

@@ -115,7 +149,7 @@ def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
def _on_lock_released(self, lock_name: str, lock_key: str) -> None:
"""Called when a lock has been released.

Wakes up any locks that might bew waiting on this.
Wakes up any locks that might be waiting on this.
"""
locks = self._locks.get((lock_name, lock_key))
if not locks:
@@ -202,3 +236,74 @@ def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)


@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
lock_names: Collection[Tuple[str, str]]

write: bool

reactor: IReactorTime
store: LockStore
handler: WorkerLocksHandler

deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)

_inner_lock_cm: Optional[AsyncContextManager] = None
_retry_interval: float = 0.1
_lock_span: "opentracing.Scope" = attr.Factory(
lambda: start_active_span("WaitingLock.lock")
)

async def __aenter__(self) -> None:
self._lock_span.__enter__()

with start_active_span("WaitingLock.waiting_for_lock"):
while self._inner_lock_cm is None:
self.deferred = defer.Deferred()

lock_cm = await self.store.try_acquire_multi_read_write_lock(
self.lock_names, write=self.write
)

if lock_cm:
self._inner_lock_cm = lock_cm
break

try:
with PreserveLoggingContext():
await timeout_deferred(
deferred=self.deferred,
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
except Exception:
pass

assert self._inner_lock_cm
await self._inner_lock_cm.__aenter__()
return

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._inner_lock_cm

for lock_name, lock_key in self.lock_names:
self.handler.notify_lock_released(lock_name, lock_key)

try:
r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
finally:
self._lock_span.__exit__(exc_type, exc, tb)

return r

def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)
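As an aside, a small self-contained sketch of the retry-interval progression produced by the `_get_next_retry_interval` helper above, with the random jitter stripped out for clarity; this re-implements the logic rather than importing it, so treat it as illustrative.

# Re-implementation of the backoff logic above, with the random jitter removed.
retry_interval = 0.1
intervals = []
for _ in range(5):
    nxt = retry_interval
    retry_interval = max(5, nxt * 2)
    intervals.append(nxt)
print(intervals)  # [0.1, 5, 10, 20, 40] -- the wait jumps to ~5s after the first poll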
187 changes: 121 additions & 66 deletions synapse/storage/databases/main/lock.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import AsyncExitStack
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary

from twisted.internet.interfaces import IReactorCore
@@ -208,76 +209,85 @@ async def try_acquire_read_write_lock(
used (otherwise the lock will leak).
"""

try:
lock = await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
self._try_acquire_read_write_lock_txn,
lock_name,
lock_key,
write,
)
except self.database_engine.module.IntegrityError:
return None

return lock

def _try_acquire_read_write_lock_txn(
self,
txn: LoggingTransaction,
lock_name: str,
lock_key: str,
write: bool,
) -> "Lock":
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.

now = self._clock.time_msec()
token = random_string(6)

def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.

delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""

insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""

if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""

return
insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""

try:
await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
_try_acquire_read_write_lock_txn,
if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
except self.database_engine.module.IntegrityError:
return None

lock = Lock(
self._reactor,
@@ -289,10 +299,55 @@ def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
token=token,
)

self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
def set_lock() -> None:
self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock

txn.call_after(set_lock)

return lock

async def try_acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
write: bool,
) -> Optional[AsyncExitStack]:
"""Try to acquire a lock for the given name/key. Will return an async
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
Review comment (Contributor): should describe the deadlock behaviour (well, actually it looks like this one doesn't block, but for the other one it might be worth saying) and the atomicity behaviour: what happens if it can't get all the locks desired, but only some of them. Does it release the locks it acquired half-way through, or does it plough on?

"""
try:
locks = await self.db_pool.runInteraction(
"try_acquire_multi_read_write_lock",
self._try_acquire_multi_read_write_lock_txn,
lock_names,
write,
)
except self.database_engine.module.IntegrityError:
return None

stack = AsyncExitStack()

for lock in locks:
await stack.enter_async_context(lock)

return stack

def _try_acquire_multi_read_write_lock_txn(
self,
txn: LoggingTransaction,
lock_names: Collection[Tuple[str, str]],
write: bool,
) -> Collection["Lock"]:
locks = []

for lock_name, lock_key in lock_names:
lock = self._try_acquire_read_write_lock_txn(
txn, lock_name, lock_key, write
)
locks.append(lock)

return locks
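For completeness, a hedged sketch of how a caller might use the store-level helper directly. It returns None when the locks cannot all be taken; since the inserts run inside a single runInteraction transaction, a failure should roll back any locks acquired part-way through, though that is an inference from transaction semantics rather than something this diff states explicitly. The function and variable names below are placeholders.

# Illustrative only: direct use of the store-level multi-lock helper.
async def try_purge(store, room_id: str) -> bool:
    lock_cm = await store.try_acquire_multi_read_write_lock(
        [("purge_history", room_id), ("delete_room", room_id)],
        write=True,
    )
    if lock_cm is None:
        # Not all locks were free; WaitingMultiLock wraps this call in a
        # retry loop for callers that want to block instead.
        return False
    async with lock_cm:
        # All requested locks are held until this block exits.
        return True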


class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
5 changes: 1 addition & 4 deletions tests/handlers/test_worker_lock.py
@@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
from synapse.util import Clock

from tests import unittest