diff --git a/changelog.d/11652.misc b/changelog.d/11652.misc new file mode 100644 index 000000000000..8e405b922674 --- /dev/null +++ b/changelog.d/11652.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. \ No newline at end of file diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index b410eefdc71f..3682cb6a8139 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -673,7 +673,7 @@ def _remove_dead_devices_from_device_inbox_txn( # There's a type mismatch here between how we want to type the row and # what fetchone says it returns, but we silence it because we know that # res can't be None. - res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment] + res = cast(Tuple[Optional[int]], txn.fetchone()) if res[0] is None: # this can only happen if the `device_inbox` table is empty, in which # case we have no work to do. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index bc5ff25d0880..270b30800bf2 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -288,7 +288,7 @@ def _get_auth_chain_ids_txn( new_front = set() for chunk in batch_iter(front, 100): # Pull the auth events either from the cache or DB. - to_fetch = [] # Event IDs to fetch from DB # type: List[str] + to_fetch: List[str] = [] # Event IDs to fetch from DB for event_id in chunk: res = self._event_auth_cache.get(event_id) if res is None: @@ -615,8 +615,8 @@ def _get_auth_chain_difference_txn( # currently walking, either from cache or DB. search, chunk = search[:-100], search[-100:] - found = [] # Results found # type: List[Tuple[str, str, int]] - to_fetch = [] # Event IDs to fetch from DB # type: List[str] + found: List[Tuple[str, str, int]] = [] # Results found + to_fetch: List[str] = [] # Event IDs to fetch from DB for _, event_id in chunk: res = self._event_auth_cache.get(event_id) if res is None: diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 98ea0e884cb3..a98e6b259378 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import attr @@ -326,7 +326,7 @@ def get_after_receipt( ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() # type: ignore[return-value] + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt @@ -357,7 +357,7 @@ def get_no_receipt( ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() # type: ignore[return-value] + return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt @@ -434,7 +434,7 @@ def get_after_receipt( ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() # type: ignore[return-value] + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt @@ -465,7 +465,7 @@ def get_no_receipt( ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) - return txn.fetchall() # type: ignore[return-value] + return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt @@ -662,7 +662,7 @@ def _find_first_stream_ordering_after_ts_txn( The stream ordering """ txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] # type: ignore[index] + max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_stream_ordering is None: return 0 @@ -731,7 +731,7 @@ def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: " LIMIT 1" ) txn.execute(sql, (stream_ordering,)) - return txn.fetchone() # type: ignore[return-value] + return cast(Optional[Tuple[int]], txn.fetchone()) result = await self.db_pool.runInteraction( "get_time_of_last_push_action_before", f @@ -1029,7 +1029,9 @@ def f( " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return txn.fetchall() # type: ignore[return-value] + return cast( + List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() + ) push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) return [ diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cf842803bcd6..cb9ee08fa8e8 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json @@ -63,7 +63,7 @@ def _do_txn(txn: LoggingTransaction) -> int: sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] # type: ignore[index] + max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_id is None: filter_id = 0 else: diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 1b076683f762..cbba356b4a98 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -23,6 +23,7 @@ Optional, Tuple, Union, + cast, ) from synapse.storage._base import SQLBaseStore @@ -220,7 +221,7 @@ def get_local_media_by_user_paginate_txn( WHERE user_id = ? """ txn.execute(sql, args) - count = txn.fetchone()[0] # type: ignore[index] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 7ab681ed6f5f..747b4f31df67 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -494,7 +494,7 @@ async def add_pusher( # invalidate, since we the user might not have had a pusher before await self.db_pool.runInteraction( "add_pusher", - self._invalidate_cache_and_stream, # type: ignore + self._invalidate_cache_and_stream, # type: ignore[attr-defined] self.get_if_user_has_pusher, (user_id,), ) @@ -503,7 +503,7 @@ async def delete_pusher_by_app_id_pushkey_user_id( self, app_id: str, pushkey: str, user_id: str ) -> None: def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( # type: ignore + self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) @@ -548,7 +548,7 @@ async def delete_all_pushers_for_user(self, user_id: str) -> None: pushers = list(await self.get_pushers_by_user_id(user_id)) def delete_pushers_txn(txn, stream_ids): - self._invalidate_cache_and_stream( # type: ignore + self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 29d9d4de9627..4175c82a253c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -1357,12 +1357,15 @@ def _use_registration_token_txn(txn): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res: Dict[str, Any] = self.db_pool.simple_select_one_txn( - txn, - "registration_tokens", - keyvalues={"token": token}, - retcols=["pending", "completed"], - ) # type: ignore + res = cast( + Dict[str, Any], + self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ), + ) # Decrement pending and increment completed self.db_pool.simple_update_one_txn( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 729ff17e2e19..4ff6aed253a3 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, cast import attr @@ -399,7 +399,7 @@ def _get_thread_summary_txn( AND relation_type = ? """ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) - count = txn.fetchone()[0] # type: ignore[index] + count = cast(Tuple[int], txn.fetchone())[0] return count, latest_event_id diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 340ca9e47d47..a1a1a6a14a85 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -225,11 +225,14 @@ def _set_ui_auth_session_data_txn( self, txn: LoggingTransaction, session_id: str, key: str, value: Any ): # Get the current value. - result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), + result = cast( + Dict[str, Any], + self.db_pool.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ), ) # Update it and add it back to the database.