Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Get the cache storage passing mypy #11354

Closed
wants to merge 17 commits into from
30 changes: 25 additions & 5 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter

Expand All @@ -39,16 +43,24 @@
# based on the current state when notifying workers over replication.
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"

# Corresponds to the (cache_func, keys, invalidation_ts) db columns.
_CacheData = Tuple[str, Optional[List[str]], Optional[int]]


class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()

async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
"""Get updates for caches replication stream.

Args:
Expand All @@ -73,7 +85,9 @@ async def get_all_updated_caches(
if last_id == current_id:
return [], current_id, False

def get_all_updated_caches_txn(txn):
def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
Expand All @@ -85,7 +99,13 @@ def get_all_updated_caches_txn(txn):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
updates = [(row[0], row[1:]) for row in txn]
updates: List[Tuple[int, _CacheData]] = []
row: Tuple[int, str, Optional[List[str]], Optional[int]]
# Type saftey: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
updates.append((row[0], row[1:]))
clokep marked this conversation as resolved.
Show resolved Hide resolved
limited = False
upto_token = current_id
if len(updates) >= limit:
Expand Down