Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add ability to wait on multiple rw locks
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Jul 19, 2023
1 parent 5db4695 commit 9afe229
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 69 deletions.
110 changes: 107 additions & 3 deletions synapse/handlers/worker_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@

import random
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
AsyncContextManager,
Collection,
Dict,
Optional,
Tuple,
Type,
Union,
)
from weakref import WeakSet

import attr
Expand Down Expand Up @@ -47,7 +56,9 @@ def __init__(self, hs: "HomeServer") -> None:

# Map from lock name/key to set of `WaitingLock` that are active for
# that lock.
self._locks: Dict[Tuple[str, str], WeakSet[WaitingLock]] = {}
self._locks: Dict[
Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
] = {}

self._clock.looping_call(self._cleanup_locks, 30_000)

Expand Down Expand Up @@ -103,6 +114,29 @@ def acquire_read_write_lock(

return lock

def acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
*,
write: bool,
) -> "WaitingMultiLock":
"""Acquires multi read/write locks at once, returns a context manager
that will block until all the locks are acquired.
"""

lock = WaitingMultiLock(
lock_names=lock_names,
write=write,
reactor=self._reactor,
store=self._store,
handler=self,
)

for lock_name, lock_key in lock_names:
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

return lock

def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
"""Notify that a lock has been released.
Expand All @@ -115,7 +149,7 @@ def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
def _on_lock_released(self, lock_name: str, lock_key: str) -> None:
"""Called when a lock has been released.
Wakes up any locks that might bew waiting on this.
Wakes up any locks that might be waiting on this.
"""
locks = self._locks.get((lock_name, lock_key))
if not locks:
Expand Down Expand Up @@ -201,3 +235,73 @@ def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)


@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
lock_names: Collection[Tuple[str, str]]

write: bool

reactor: IReactorTime
store: LockStore
handler: WorkerLocksHandler

deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)

_inner_lock_cm: Optional[AsyncContextManager] = None
_retry_interval: float = 0.1
_lock_span: "opentracing.Scope" = attr.Factory(
lambda: start_active_span("WaitingLock.lock")
)

async def __aenter__(self) -> None:
self._lock_span.__enter__()

with start_active_span("WaitingLock.waiting_for_lock"):
while self._inner_lock_cm is None:
lock_cm = await self.store.try_acquire_multi_read_write_lock(
self.lock_names, write=self.write
)

if lock_cm:
self._inner_lock_cm = lock_cm
break

self.deferred = defer.Deferred()
try:
with PreserveLoggingContext():
await timeout_deferred(
deferred=self.deferred,
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
except Exception:
pass

assert self._inner_lock_cm
await self._inner_lock_cm.__aenter__()
return

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._inner_lock_cm

for lock_name, lock_key in self.lock_names:
self.handler.notify_lock_released(lock_name, lock_key)

try:
r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
finally:
self._lock_span.__exit__(exc_type, exc, tb)

return r

def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)
187 changes: 121 additions & 66 deletions synapse/storage/databases/main/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import AsyncExitStack
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary

from twisted.internet.interfaces import IReactorCore
Expand Down Expand Up @@ -208,76 +209,85 @@ async def try_acquire_read_write_lock(
used (otherwise the lock will leak).
"""

try:
lock = await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
self._try_acquire_read_write_lock_txn,
lock_name,
lock_key,
write,
)
except self.database_engine.module.IntegrityError:
return None

return lock

def _try_acquire_read_write_lock_txn(
self,
txn: LoggingTransaction,
lock_name: str,
lock_key: str,
write: bool,
) -> "Lock":
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.

now = self._clock.time_msec()
token = random_string(6)

def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.

delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""

insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""

if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""

return
insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""

try:
await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
_try_acquire_read_write_lock_txn,
if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
except self.database_engine.module.IntegrityError:
return None

lock = Lock(
self._reactor,
Expand All @@ -289,10 +299,55 @@ def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
token=token,
)

self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
def set_lock() -> None:
self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock

txn.call_after(set_lock)

return lock

async def try_acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
write: bool,
) -> Optional[AsyncExitStack]:
"""Try to acquire a lock for the given name/key. Will return an async
context manager if the lock is successfully acquired, which *must* be
used (otherwise the lock will leak).
"""
try:
locks = await self.db_pool.runInteraction(
"try_acquire_multi_read_write_lock",
self._try_acquire_multi_read_write_lock_txn,
lock_names,
write,
)
except self.database_engine.module.IntegrityError:
return None

stack = AsyncExitStack()

for lock in locks:
await stack.enter_async_context(lock)

return stack

def _try_acquire_multi_read_write_lock_txn(
self,
txn: LoggingTransaction,
lock_names: Collection[Tuple[str, str]],
write: bool,
) -> Collection["Lock"]:
locks = []

for lock_name, lock_key in lock_names:
lock = self._try_acquire_read_write_lock_txn(
txn, lock_name, lock_key, write
)
locks.append(lock)

return locks


class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
Expand Down

0 comments on commit 9afe229

Please sign in to comment.