From 08cfd7f5376774312a10b171d47126721e618227 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 8 Apr 2022 19:02:27 +0200 Subject: [PATCH 1/9] Add some type hints to datastore --- synapse/storage/databases/main/appservice.py | 28 ++-- .../storage/databases/main/registration.py | 129 +++++++++++------- synapse/storage/databases/main/relations.py | 2 +- synapse/storage/databases/main/signatures.py | 2 +- synapse/storage/databases/main/state.py | 2 +- synapse/storage/databases/main/stream.py | 27 +++- synapse/storage/databases/main/tags.py | 4 +- 7 files changed, 120 insertions(+), 74 deletions(-) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index eb32c34a855f..fda5574ef4d0 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Pattern, Tuple from synapse.appservice import ( ApplicationService, @@ -26,7 +26,11 @@ from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor @@ -92,7 +96,7 @@ def get_max_as_txn_id(txn: Cursor) -> int: super().__init__(database, db_conn, hs) - def get_app_services(self): + def get_app_services(self) -> List[ApplicationService]: return self.services_cache def get_if_app_services_interested_in_user(self, user_id: str) -> bool: @@ -256,7 +260,7 @@ async def create_appservice_txn( A new transaction. """ - def _create_appservice_txn(txn): + def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction: new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn) # Insert new txn into txn table @@ -291,7 +295,7 @@ async def complete_appservice_txn( service: The application service which was sent this transaction. """ - def _complete_appservice_txn(txn): + def _complete_appservice_txn(txn: LoggingTransaction) -> None: # Set current txn_id for AS to 'txn_id' self.db_pool.simple_upsert_txn( txn, @@ -322,7 +326,9 @@ async def get_oldest_unsent_txn( An AppServiceTransaction or None. """ - def _get_oldest_unsent_txn(txn): + def _get_oldest_unsent_txn( + txn: LoggingTransaction, + ) -> Optional[Dict[str, Any]]: # Monotonically increasing txn ids, so just select the smallest # one in the txns table (we delete them when they are sent) txn.execute( @@ -364,7 +370,7 @@ def _get_oldest_unsent_txn(txn): ) async def set_appservice_last_pos(self, pos: int) -> None: - def set_appservice_last_pos_txn(txn): + def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None: txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) @@ -378,7 +384,9 @@ async def get_new_events_for_appservice( ) -> Tuple[int, List[EventBase]]: """Get all new events for an appservice""" - def get_new_events_for_appservice_txn(txn): + def get_new_events_for_appservice_txn( + txn: LoggingTransaction, + ) -> Tuple[int, Collection[str]]: sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" @@ -416,7 +424,7 @@ async def get_type_stream_id_for_appservice( % (type,) ) - def get_type_stream_id_for_appservice_txn(txn): + def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int: stream_id_type = "%s_stream_id" % type txn.execute( # We do NOT want to escape `stream_id_type`. @@ -444,7 +452,7 @@ async def set_appservice_stream_type_pos( % (stream_type,) ) - def set_appservice_stream_type_pos_txn(txn): + def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None: stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c7634c92fd37..bfed5edf2799 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -34,7 +34,7 @@ from synapse.storage.types import Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID, UserInfo +from synapse.types import JsonDict, UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -79,7 +79,7 @@ class TokenLookupResult: # Make the token owner default to the user ID, which is the common case. @token_owner.default - def _default_token_owner(self): + def _default_token_owner(self) -> str: return self.user_id @@ -121,7 +121,7 @@ def __init__( database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) self.config: HomeServerConfig = hs.config @@ -299,7 +299,7 @@ async def set_account_validity_for_user( the account. """ - def set_account_validity_for_user_txn(txn): + def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_update_txn( txn=txn, table="account_validity", @@ -394,7 +394,9 @@ async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: A list of dictionaries, each with a user ID and expiration time (in milliseconds). """ - def select_users_txn(txn, now_ms, renew_at): + def select_users_txn( + txn: LoggingTransaction, now_ms: int, renew_at: int + ) -> List[Dict[str, Any]]: sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" @@ -466,7 +468,7 @@ async def set_server_admin(self, user: UserID, admin: bool) -> None: admin: true iff the user is to be a server admin, false otherwise. """ - def set_server_admin_txn(txn): + def set_server_admin_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} ) @@ -515,7 +517,7 @@ async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> N user_type: type of the user or None for a user without a type. """ - def set_user_type_txn(txn): + def set_user_type_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, "users", {"name": user.to_string()}, {"user_type": user_type} ) @@ -525,7 +527,9 @@ def set_user_type_txn(txn): await self.db_pool.runInteraction("set_user_type", set_user_type_txn) - def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: + def _query_for_auth( + self, txn: LoggingTransaction, token: str + ) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, users.is_guest, @@ -582,7 +586,7 @@ async def is_support_user(self, user_id: str) -> bool: "is_support_user", self.is_support_user_txn, user_id ) - def is_real_user_txn(self, txn, user_id): + def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool: res = self.db_pool.simple_select_one_onecol_txn( txn=txn, table="users", @@ -592,7 +596,7 @@ def is_real_user_txn(self, txn, user_id): ) return res is None - def is_support_user_txn(self, txn, user_id): + def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool: res = self.db_pool.simple_select_one_onecol_txn( txn=txn, table="users", @@ -609,10 +613,11 @@ async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str] A mapping of user_id -> password_hash. """ - def f(txn): + def f(txn: LoggingTransaction) -> Dict[str, str]: sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" txn.execute(sql, (user_id,)) - return dict(txn) + result = cast(List[Tuple[str, str]], txn.fetchall()) + return dict(result) return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) @@ -734,7 +739,7 @@ def _remove_user_external_ids_txn( def _replace_user_external_id_txn( txn: LoggingTransaction, - ): + ) -> None: _remove_user_external_ids_txn(txn, user_id) for auth_provider, external_id in record_external_ids: @@ -790,10 +795,10 @@ async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]: ) return [(r["auth_provider"], r["external_id"]) for r in res] - async def count_all_users(self): + async def count_all_users(self) -> int: """Counts all users registered on the homeserver.""" - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: txn.execute("SELECT COUNT(*) AS users FROM users") rows = self.db_pool.cursor_to_dict(txn) if rows: @@ -810,7 +815,7 @@ async def count_daily_user_type(self) -> Dict[str, int]: who registered on the homeserver in the past 24 hours """ - def _count_daily_user_type(txn): + def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]: yesterday = int(self._clock.time()) - (60 * 60 * 24) sql = """ @@ -835,23 +840,23 @@ def _count_daily_user_type(txn): "count_daily_user_type", _count_daily_user_type ) - async def count_nonbridged_users(self): - def _count_users(txn): + async def count_nonbridged_users(self) -> int: + def _count_users(txn: LoggingTransaction) -> int: txn.execute( """ SELECT COUNT(*) FROM users WHERE appservice_id IS NULL """ ) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_users", _count_users) - async def count_real_users(self): + async def count_real_users(self) -> int: """Counts all users without a special user_type registered on the homeserver.""" - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") rows = self.db_pool.cursor_to_dict(txn) if rows: @@ -888,7 +893,7 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s return user_id def get_user_id_by_threepid_txn( - self, txn, medium: str, address: str + self, txn: LoggingTransaction, medium: str, address: str ) -> Optional[str]: """Returns user id from threepid @@ -925,7 +930,7 @@ async def user_add_threepid( {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: + async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -957,7 +962,7 @@ async def user_delete_threepids(self, user_id: str) -> None: async def add_user_bound_threepid( self, user_id: str, medium: str, address: str, id_server: str - ): + ) -> None: """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. @@ -1116,7 +1121,9 @@ async def get_threepid_validation_session( assert address or sid - def get_threepid_validation_session_txn(txn): + def get_threepid_validation_session_txn( + txn: LoggingTransaction, + ) -> Optional[Dict[str, Any]]: sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at @@ -1150,7 +1157,7 @@ async def delete_threepid_session(self, session_id: str) -> None: session_id: The ID of the session to delete """ - def delete_threepid_session_txn(txn): + def delete_threepid_session_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="threepid_validation_token", @@ -1170,7 +1177,9 @@ def delete_threepid_session_txn(txn): async def cull_expired_threepid_validation_tokens(self) -> None: """Remove threepid validation tokens with expiry dates that have passed""" - def cull_expired_threepid_validation_tokens_txn(txn, ts): + def cull_expired_threepid_validation_tokens_txn( + txn: LoggingTransaction, ts: int + ) -> None: sql = """ DELETE FROM threepid_validation_token WHERE expires < ? @@ -1184,13 +1193,13 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts): ) @wrap_as_background_process("account_validity_set_expiration_dates") - async def _set_expiration_date_when_missing(self): + async def _set_expiration_date_when_missing(self) -> None: """ Retrieves the list of registered users that don't have an expiration date, and adds an expiration date for each of them. """ - def select_users_with_no_expiration_date_txn(txn): + def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None: """Retrieves the list of registered users with no expiration date from the database, filtering out deactivated users. """ @@ -1213,7 +1222,9 @@ def select_users_with_no_expiration_date_txn(txn): select_users_with_no_expiration_date_txn, ) - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + def set_expiration_date_for_user_txn( + self, txn: LoggingTransaction, user_id: str, use_delta: bool = False + ) -> None: """Sets an expiration date to the account with the given user ID. Args: @@ -1344,7 +1355,7 @@ async def set_registration_token_pending(self, token: str) -> None: token: The registration token pending use """ - def _set_registration_token_pending_txn(txn): + def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None: pending = self.db_pool.simple_select_one_onecol_txn( txn, "registration_tokens", @@ -1358,7 +1369,7 @@ def _set_registration_token_pending_txn(txn): updatevalues={"pending": pending + 1}, ) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_registration_token_pending", _set_registration_token_pending_txn ) @@ -1372,7 +1383,7 @@ async def use_registration_token(self, token: str) -> None: token: The registration token to be 'used' """ - def _use_registration_token_txn(txn): + def _use_registration_token_txn(txn: LoggingTransaction) -> None: # Normally, res is Optional[Dict[str, Any]]. # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors @@ -1398,7 +1409,7 @@ def _use_registration_token_txn(txn): }, ) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "use_registration_token", _use_registration_token_txn ) @@ -1416,7 +1427,9 @@ async def get_registration_tokens( A list of dicts, each containing details of a token. """ - def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + def select_registration_tokens_txn( + txn: LoggingTransaction, now: int, valid: Optional[bool] + ) -> List[Dict[str, Any]]: if valid is None: # Return all tokens regardless of validity txn.execute("SELECT * FROM registration_tokens") @@ -1523,7 +1536,7 @@ async def create_registration_token( Whether the row was inserted or not. """ - def _create_registration_token_txn(txn): + def _create_registration_token_txn(txn: LoggingTransaction) -> bool: row = self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1570,7 +1583,9 @@ async def update_registration_token( A dict with all info about the token, or None if token doesn't exist. """ - def _update_registration_token_txn(txn): + def _update_registration_token_txn( + txn: LoggingTransaction, + ) -> Optional[Dict[str, Any]]: try: self.db_pool.simple_update_one_txn( txn, @@ -1651,7 +1666,9 @@ async def lookup_refresh_token( ) -> Optional[RefreshTokenLookupResult]: """Lookup a refresh token with hints about its validity.""" - def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + def _lookup_refresh_token_txn( + txn: LoggingTransaction, + ) -> Optional[RefreshTokenLookupResult]: txn.execute( """ SELECT @@ -1807,14 +1824,18 @@ def __init__( unique=False, ) - async def _background_update_set_deactivated_flag(self, progress, batch_size): + async def _background_update_set_deactivated_flag( + self, progress: JsonDict, batch_size: int + ) -> int: """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. """ last_user = progress.get("user_id", "") - def _background_update_set_deactivated_flag_txn(txn): + def _background_update_set_deactivated_flag_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, int]: txn.execute( """ SELECT @@ -1886,7 +1907,9 @@ async def set_user_deactivated_status( deactivated, ) - def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): + def set_user_deactivated_status_txn( + self, txn: LoggingTransaction, user_id: str, deactivated: bool + ) -> None: self.db_pool.simple_update_one_txn( txn=txn, table="users", @@ -2005,7 +2028,9 @@ async def add_refresh_token_to_user( return next_id - def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + def _set_device_for_access_token_txn( + self, txn: LoggingTransaction, token: str, device_id: str + ) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" ) @@ -2084,7 +2109,7 @@ async def register_user( def _register_user( self, - txn, + txn: LoggingTransaction, user_id: str, password_hash: Optional[str], was_guest: bool, @@ -2094,7 +2119,7 @@ def _register_user( admin: bool, user_type: Optional[str], shadow_banned: bool, - ): + ) -> None: user_id_obj = UserID.from_string(user_id) now = int(self._clock.time()) @@ -2181,7 +2206,7 @@ async def user_set_password_hash( pointless. Use flush_user separately. """ - def user_set_password_hash_txn(txn): + def user_set_password_hash_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) @@ -2204,7 +2229,7 @@ async def user_set_consent_version( StoreError(404) if user not found """ - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, table="users", @@ -2229,7 +2254,7 @@ async def user_set_consent_server_notice_sent( StoreError(404) if user not found """ - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_update_one_txn( txn, table="users", @@ -2259,7 +2284,7 @@ async def user_delete_access_tokens( A tuple of (token, token id, device id) for each of the deleted tokens """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]: keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id @@ -2301,7 +2326,7 @@ def f(txn): return await self.db_pool.runInteraction("user_delete_access_tokens", f) async def delete_access_token(self, access_token: str) -> None: - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -2313,7 +2338,7 @@ def f(txn): await self.db_pool.runInteraction("delete_access_token", f) async def delete_refresh_token(self, refresh_token: str) -> None: - def f(txn): + def f(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_one_txn( txn, table="refresh_tokens", keyvalues={"token": refresh_token} ) @@ -2353,7 +2378,7 @@ async def validate_threepid_session( """ # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn): + def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]: row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_session", @@ -2450,7 +2475,7 @@ async def start_or_continue_validation_session( longer be valid """ - def start_or_continue_validation_session_txn(txn): + def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None: # Create or update a validation session self.db_pool.simple_upsert_txn( txn, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 64a78081402e..38070477ee27 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -608,7 +608,7 @@ async def events_have_relations( %s; """ - def _get_if_events_have_relations(txn) -> List[str]: + def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]: clauses: List[str] = [] clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", parent_ids diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index 0518b8b910e0..95148fd2273a 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -26,7 +26,7 @@ class SignatureWorkerStore(EventsWorkerStore): @cached() - def get_event_reference_hash(self, event_id): + def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: # This is a dummy function to allow get_event_reference_hashes # to use its cache raise NotImplementedError() diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4a461a0abb1f..ecdc1fdc4c13 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -204,7 +204,7 @@ async def get_current_state_ids(self, room_id: str) -> StateMap[str]: The current state of the room. """ - def _get_current_state_ids_txn(txn): + def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]: txn.execute( """SELECT type, state_key, event_id FROM current_state_events WHERE room_id = ? diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 82e9ef02d269..17fc4aa1c484 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,7 +36,18 @@ """ import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + cast, +) import attr from frozendict import frozendict @@ -732,7 +743,7 @@ async def get_room_event_before_stream_ordering( A tuple of (stream ordering, topological ordering, event_id) """ - def _f(txn): + def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]: sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" @@ -742,7 +753,7 @@ def _f(txn): " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering)) - return txn.fetchone() + return cast(Tuple[int, int, str], txn.fetchone()) return await self.db_pool.runInteraction( "get_room_event_before_stream_ordering", _f @@ -770,7 +781,7 @@ def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, - allow_none=False, + allow_none: Literal[False] = False, ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn=txn, @@ -839,7 +850,7 @@ def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int @staticmethod def _set_before_and_after( events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True - ): + ) -> None: """Inserts ordering information to events' internal metadata from the DB rows. @@ -985,7 +996,9 @@ async def get_all_new_events_stream( the `current_id`). """ - def get_all_new_events_stream_txn(txn): + def get_all_new_events_stream_txn( + txn: LoggingTransaction, + ) -> Tuple[int, Collection[str]]: sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" @@ -1331,7 +1344,7 @@ async def paginate_room_events( async def get_id_for_instance(self, instance_name: str) -> int: """Get a unique, immutable ID that corresponds to the given Synapse worker instance.""" - def _get_id_for_instance_txn(txn): + def _get_id_for_instance_txn(txn: LoggingTransaction) -> int: instance_id = self.db_pool.simple_select_one_onecol_txn( txn, table="instance_map", diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index c8e508a910fb..b0f5de67a30d 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -97,7 +97,7 @@ def get_all_updated_tags_txn( ) def get_tag_content( - txn: LoggingTransaction, tag_ids + txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]] ) -> List[Tuple[int, Tuple[str, str, str]]]: sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" results = [] @@ -251,7 +251,7 @@ def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: return self._account_data_id_gen.get_current_token() def _update_revision_txn( - self, txn, user_id: str, room_id: str, next_id: int + self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int ) -> None: """Update the latest revision of the tags for the given user and room. From 78f265af2675fd381b3b532e5e16aeafa97c09f4 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 8 Apr 2022 19:06:08 +0200 Subject: [PATCH 2/9] newsfile --- changelog.d/12423.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12423.misc diff --git a/changelog.d/12423.misc b/changelog.d/12423.misc new file mode 100644 index 000000000000..e793d08e5e3f --- /dev/null +++ b/changelog.d/12423.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file From 5ba96624932ac77ab24b5e54da10890138ac93d2 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 9 Apr 2022 11:38:54 +0200 Subject: [PATCH 3/9] change `Collection` to `List` --- synapse/storage/databases/main/appservice.py | 4 ++-- synapse/storage/databases/main/stream.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index fda5574ef4d0..fa732edcca08 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Pattern, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple from synapse.appservice import ( ApplicationService, @@ -386,7 +386,7 @@ async def get_new_events_for_appservice( def get_new_events_for_appservice_txn( txn: LoggingTransaction, - ) -> Tuple[int, Collection[str]]: + ) -> Tuple[int, List[str]]: sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 17fc4aa1c484..5783d5b7c9aa 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -998,7 +998,7 @@ async def get_all_new_events_stream( def get_all_new_events_stream_txn( txn: LoggingTransaction, - ) -> Tuple[int, Collection[str]]: + ) -> Tuple[int, List[str]]: sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" From e7fc822cd418af068af319e8fc9ae304ee146dd4 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 9 Apr 2022 12:00:31 +0200 Subject: [PATCH 4/9] refactor return type of `select_users_txn` --- synapse/handlers/account_validity.py | 4 ++-- synapse/storage/databases/main/registration.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 9d0975f636ab..05a138410e25 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -180,9 +180,9 @@ async def _send_renewal_emails(self) -> None: expiring_users = await self.store.get_users_expiring_soon() if expiring_users: - for user in expiring_users: + for user_id, expiration_ts_ms in expiring_users: await self._send_renewal_email( - user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] + user_id=user_id, expiration_ts=expiration_ts_ms ) async def send_renewal_email_to_user(self, user_id: str) -> None: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index bfed5edf2799..61ad091407d4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -385,25 +385,25 @@ async def get_renewal_token_for_user(self, user_id: str) -> str: desc="get_renewal_token_for_user", ) - async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: + async def get_users_expiring_soon(self) -> Optional[List[Tuple[str, int]]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - A list of dictionaries, each with a user ID and expiration time (in milliseconds). + A list of tuples, each with a user ID and expiration time (in milliseconds). """ def select_users_txn( txn: LoggingTransaction, now_ms: int, renew_at: int - ) -> List[Dict[str, Any]]: + ) -> Optional[List[Tuple[str, int]]]: sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.db_pool.cursor_to_dict(txn) + return cast(Optional[List[Tuple[str, int]]], txn.fetchall()) return await self.db_pool.runInteraction( "get_users_expiring_soon", From 2efec26d6b2b8508eda5c443dd31e6ee7772860b Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 9 Apr 2022 12:09:01 +0200 Subject: [PATCH 5/9] correct type hint in `stream.py` --- synapse/storage/databases/main/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5783d5b7c9aa..58eacfe2a147 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -42,7 +42,6 @@ Collection, Dict, List, - Literal, Optional, Set, Tuple, @@ -51,6 +50,7 @@ import attr from frozendict import frozendict +from typing_extensions import Literal from twisted.internet import defer @@ -753,7 +753,7 @@ def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]: " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering)) - return cast(Tuple[int, int, str], txn.fetchone()) + return cast(Optional[Tuple[int, int, str]], txn.fetchone()) return await self.db_pool.runInteraction( "get_room_event_before_stream_ordering", _f From 284a836af6825e87c2bc3839d6490062e5ecd050 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 9 Apr 2022 12:29:42 +0200 Subject: [PATCH 6/9] Remove `Optional` in `select_users_txn` --- synapse/storage/databases/main/registration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 61ad091407d4..bdd124aa8fca 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -385,7 +385,7 @@ async def get_renewal_token_for_user(self, user_id: str) -> str: desc="get_renewal_token_for_user", ) - async def get_users_expiring_soon(self) -> Optional[List[Tuple[str, int]]]: + async def get_users_expiring_soon(self) -> List[Tuple[str, int]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). @@ -396,14 +396,14 @@ async def get_users_expiring_soon(self) -> Optional[List[Tuple[str, int]]]: def select_users_txn( txn: LoggingTransaction, now_ms: int, renew_at: int - ) -> Optional[List[Tuple[str, int]]]: + ) -> List[Tuple[str, int]]: sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return cast(Optional[List[Tuple[str, int]]], txn.fetchall()) + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_users_expiring_soon", From 38533a3857eca33b6420d6387efdb358dcd26458 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sat, 9 Apr 2022 12:32:45 +0200 Subject: [PATCH 7/9] remove not needed return type in `__init__` --- synapse/storage/databases/main/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index bdd124aa8fca..d43163c27cae 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -121,7 +121,7 @@ def __init__( database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ) -> None: + ): super().__init__(database, db_conn, hs) self.config: HomeServerConfig = hs.config From 6d75f963a74d364bc67c01525940673392b5dcd5 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 11 Apr 2022 17:00:32 +0200 Subject: [PATCH 8/9] Revert change in `get_stream_id_for_event_txn` --- synapse/storage/databases/main/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 58eacfe2a147..9f544cd787f1 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -781,7 +781,7 @@ def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, - allow_none: Literal[False] = False, + allow_none=False, ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn=txn, From 3edfa0c38f3e41b3cfc06b3358c0d34d95797778 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 11 Apr 2022 17:05:22 +0200 Subject: [PATCH 9/9] Remove import from `Literal` --- synapse/storage/databases/main/stream.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 9f544cd787f1..6d45a8a9f6cd 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -50,7 +50,6 @@ import attr from frozendict import frozendict -from typing_extensions import Literal from twisted.internet import defer