Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/storage/databases/main/e2e_room_keys.py #11549

Merged
merged 8 commits into from
Dec 14, 2021
66 changes: 55 additions & 11 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,18 @@ async def update_e2e_room_key(
Raises:
StoreError
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No row found")
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
"version": version,
"version": version_int,
"room_id": room_id,
"session_id": session_id,
},
Expand All @@ -85,13 +91,19 @@ async def add_e2e_room_keys(
version: the version ID of the backup for the set of keys we're adding to
room_keys: the keys to add, in the form (roomID, sessionID, keyData)
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No row found")

values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
{
"user_id": user_id,
"version": version,
"version": version_int,
"room_id": room_id,
"session_id": session_id,
"first_message_index": room_key["first_message_index"],
Expand Down Expand Up @@ -203,20 +215,26 @@ async def get_e2e_room_keys_multi(
Returns:
A map of room IDs to session IDs to room key
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return {}

return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
version,
version_int,
room_keys,
)

@staticmethod
def _get_e2e_room_keys_multi_txn(
txn: LoggingTransaction,
user_id: str,
version: str,
version: int,
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys:
Expand Down Expand Up @@ -272,10 +290,16 @@ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return 0

return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
Expand All @@ -301,8 +325,14 @@ async def delete_e2e_room_keys(
If not specified, we delete all the keys in this version of
the backup (or for the specified room)
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

keyvalues = {"user_id": user_id, "version": int(version)}
keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
Expand All @@ -319,6 +349,8 @@ def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
"WHERE user_id=? AND deleted=0",
(user_id,),
)
# `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
# be `NULL` when there are no available versions.
row = cast(Tuple[Optional[int]], txn.fetchone())
if row[0] is None:
raise StoreError(404, "No current backup version")
Expand Down Expand Up @@ -395,7 +427,7 @@ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
if current_version is None:
current_version = 0

new_version = str(int(current_version) + 1)
new_version = current_version + 1

self.db_pool.simple_insert_txn(
txn,
Expand All @@ -408,7 +440,7 @@ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
},
)

return new_version
return str(new_version)

return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
Expand Down Expand Up @@ -440,9 +472,16 @@ async def update_e2e_room_keys_version(
updatevalues["etag"] = version_etag

if updatevalues:
await self.db_pool.simple_update(
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No row found")

await self.db_pool.simple_update_one(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
Expand All @@ -467,7 +506,12 @@ def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
this_version = int(version)
try:
this_version = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
raise StoreError(404, "No row found")

self.db_pool.simple_delete_txn(
txn,
Expand Down