This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to some storage classes (#11307)
clokep authored Nov 11, 2021
1 parent 6ce19b9 commit 64ef253
Showing 9 changed files with 116 additions and 54 deletions.
1 change: 1 addition & 0 deletions changelog.d/11307.misc
@@ -0,0 +1 @@
Add type hints to storage classes.
7 changes: 0 additions & 7 deletions mypy.ini
@@ -27,8 +27,6 @@ exclude = (?x)
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/censor_events.py
|synapse/storage/databases/main/deviceinbox.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/directory.py
|synapse/storage/databases/main/e2e_room_keys.py
@@ -38,19 +36,15 @@ exclude = (?x)
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/events_forward_extremities.py
|synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/filtering.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/lock.py
|synapse/storage/databases/main/media_repository.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/openid.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/profile.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/rejections.py
|synapse/storage/databases/main/room.py
|synapse/storage/databases/main/room_batch.py
|synapse/storage/databases/main/roommember.py
@@ -59,7 +53,6 @@ exclude = (?x)
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/state_deltas.py
|synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/tags.py
|synapse/storage/databases/main/transactions.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/databases/main/user_erasure_store.py
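Removing these modules from the `exclude` regex means mypy now type-checks them; the per-module annotations in the rest of this commit are what make that feasible. A minimal standalone sketch of the general pattern, using a hypothetical FakeTransaction protocol rather than Synapse's actual LoggingTransaction:

from typing import Optional, Protocol


class FakeTransaction(Protocol):
    """Stand-in for a database transaction object (hypothetical, not Synapse's)."""

    def execute(self, sql: str, args: tuple = ()) -> None:
        ...

    def fetchone(self) -> Optional[tuple]:
        ...


# Before: `txn` is implicitly Any, so mypy cannot check anything done with it.
def untyped_callback(txn):
    txn.execute("SELECT 1")


# After: with the parameter annotated, a typo such as `txn.excute(...)` or a
# wrongly typed argument is reported by mypy.
def typed_callback(txn: FakeTransaction) -> None:
    txn.execute("SELECT 1")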
30 changes: 16 additions & 14 deletions synapse/storage/databases/main/censor_events.py
@@ -13,12 +13,12 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
@@ -41,7 +41,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)

@wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self):
async def _censor_redactions(self) -> None:
"""Censors all redactions older than the configured period that haven't
been censored yet.
@@ -105,7 +105,7 @@ async def _censor_redactions(self):
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = json_encoder.encode(
pruned_json: Optional[str] = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -116,7 +116,7 @@ async def _censor_redactions(self):

updates.append((redaction_id, event_id, pruned_json))

def _update_censor_txn(txn):
def _update_censor_txn(txn: LoggingTransaction) -> None:
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
@@ -130,14 +130,16 @@ def _update_censor_txn(txn):

await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)

def _censor_event_txn(self, txn, event_id, pruned_json):
def _censor_event_txn(
self, txn: LoggingTransaction, event_id: str, pruned_json: str
) -> None:
"""Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON.
Args:
txn (LoggingTransaction): The database transaction.
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
txn: The database transaction.
event_id: The ID of the event to censor.
pruned_json: The pruned JSON
"""
self.db_pool.simple_update_one_txn(
txn,
@@ -157,7 +159,7 @@ async def expire_event(self, event_id: str) -> None:
# Try to retrieve the event's content from the database or the event cache.
event = await self.get_event(event_id)

def delete_expired_event_txn(txn):
def delete_expired_event_txn(txn: LoggingTransaction) -> None:
# Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id)

@@ -194,14 +196,14 @@ def delete_expired_event_txn(txn):
"delete_expired_event", delete_expired_event_txn
)

def _delete_event_expiry_txn(self, txn, event_id):
def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
"""Delete the expiry timestamp associated with an event ID without deleting the
actual event.
Args:
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
txn: The transaction to use to perform the deletion.
event_id: The event ID to delete the associated expiry timestamp of.
"""
return self.db_pool.simple_delete_txn(
self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
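The explicit Optional[str] annotation on pruned_json is needed because the variable can later be set to None (the subsequent `if pruned_json:` check relies on that); mypy would otherwise infer str from the first assignment and reject the None branch. A minimal sketch of the same situation, using a hypothetical pick_json helper:

import json
from typing import Optional


def pick_json(censor: bool, payload: dict) -> Optional[str]:
    # Annotating the first assignment widens the inferred type from str to
    # Optional[str], so the None assignment below still type-checks.
    pruned_json: Optional[str] = json.dumps(payload)
    if not censor:
        pruned_json = None
    return pruned_json


print(pick_json(True, {"a": 1}))   # '{"a": 1}'
print(pick_json(False, {"a": 1}))  # None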
52 changes: 36 additions & 16 deletions synapse/storage/databases/main/deviceinbox.py
@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,9 +20,17 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@@ -34,14 +43,21 @@


class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()

# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
self._last_device_delete_cache: ExpiringCache[
Tuple[str, Optional[str]], int
] = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
@@ -53,14 +69,16 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name in hs.config.worker.writers.to_device
)

self._device_inbox_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
self._device_inbox_id_gen: AbstractStreamIdGenerator = (
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
)
else:
self._can_write_to_device = True
@@ -101,6 +119,8 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
# If replication is happening then Postgres must be in use.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
@@ -220,11 +240,11 @@ def delete_messages_for_device_txn(txn):
log_kv({"message": f"deleted {count} messages for device", "count": count})

# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
updated_last_deleted_stream_id, up_to_stream_id
)

return count
@@ -432,7 +452,7 @@ def add_messages_txn(txn, now_ms, stream_id):
)

async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
@@ -483,7 +503,7 @@ def add_messages_txn(txn, now_ms, stream_id):
)

async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
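Annotating _device_inbox_id_gen as AbstractStreamIdGenerator keeps both the multi-writer and single-writer branches valid, while the new assert isinstance(...) in process_replication_rows narrows the type where a MultiWriterIdGenerator-only method (advance) is called. A minimal sketch of that pattern with hypothetical classes, not Synapse's real generators:

import abc
from typing import Dict


class AbstractIdGenerator(abc.ABC):
    @abc.abstractmethod
    def get_current_token(self) -> int:
        ...


class MultiWriterIdGen(AbstractIdGenerator):
    def __init__(self) -> None:
        self._tokens: Dict[str, int] = {}

    def get_current_token(self) -> int:
        return max(self._tokens.values(), default=0)

    def advance(self, instance: str, token: int) -> None:
        # Only the multi-writer implementation tracks per-instance tokens.
        self._tokens[instance] = max(self._tokens.get(instance, 0), token)


class SingleWriterIdGen(AbstractIdGenerator):
    def __init__(self) -> None:
        self._token = 0

    def get_current_token(self) -> int:
        return self._token


class Store:
    def __init__(self, multi_writer: bool) -> None:
        # Typed as the abstraction, so either branch satisfies the annotation.
        self._id_gen: AbstractIdGenerator = (
            MultiWriterIdGen() if multi_writer else SingleWriterIdGen()
        )

    def on_replication(self, instance: str, token: int) -> None:
        # Replication only happens in the multi-writer configuration, so
        # narrow the type before calling the subclass-only method.
        assert isinstance(self._id_gen, MultiWriterIdGen)
        self._id_gen.advance(instance, token)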
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/filtering.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@

from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

@@ -49,7 +51,7 @@ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> i

# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
def _do_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
Expand All @@ -61,7 +63,7 @@ def _do_txn(txn):

sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
max_id = txn.fetchone()[0] # type: ignore[index]
if max_id is None:
filter_id = 0
else:
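The `# type: ignore[index]` is needed because fetchone() is presumably annotated as returning an optional row, so indexing its result without a None check is rejected even when the query is known to return exactly one row. A sketch of the explicit-check alternative, using sqlite3 as a stand-in for Synapse's transaction object:

import sqlite3
from typing import Optional


def max_filter_id(conn: sqlite3.Connection, user_id: str) -> int:
    cur = conn.execute(
        "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?", (user_id,)
    )
    # fetchone() may be None in general, so check explicitly instead of
    # indexing straight into it.
    row: Optional[tuple] = cur.fetchone()
    if row is None or row[0] is None:
        return 0
    return row[0] + 1


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_filters (user_id TEXT, filter_id INTEGER, filter_json TEXT)")
print(max_filter_id(conn, "@alice:example.com"))  # 0, since MAX() over no rows is NULL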
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Tuple, Type
from weakref import WeakValueDictionary

from twisted.internet.interfaces import IReactorCore
@@ -62,7 +62,9 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"

# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
self._live_tokens: WeakValueDictionary[
Tuple[str, str], Lock
] = WeakValueDictionary()

# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
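Annotating _live_tokens as WeakValueDictionary[...] rather than Dict[...] keeps the type precise: entries disappear as soon as the Lock values they point to are garbage-collected, which a plain Dict annotation would not convey. A small standalone sketch with a hypothetical FakeLock (the `from __future__ import annotations` import keeps the subscripted annotation from being evaluated at runtime on older Pythons):

from __future__ import annotations

from weakref import WeakValueDictionary


class FakeLock:
    """Stand-in for Synapse's Lock; it only needs to be weak-referenceable."""


live_tokens: WeakValueDictionary[tuple[str, str], FakeLock] = WeakValueDictionary()

lock = FakeLock()
live_tokens[("lock_name", "lock_key")] = lock
print(len(live_tokens))  # 1 while a strong reference to `lock` exists

del lock
print(len(live_tokens))  # typically 0; the entry vanished with its value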
17 changes: 16 additions & 1 deletion synapse/storage/databases/main/openid.py
@@ -1,6 +1,21 @@
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction


class OpenIdStore(SQLBaseStore):
@@ -20,7 +35,7 @@ async def insert_open_id_token(
async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int
) -> Optional[str]:
def get_user_id_for_token_txn(txn):
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
sql = (
"SELECT user_id FROM open_id_tokens"
" WHERE token = ? AND ? <= ts_valid_until_ms"
