Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add some type hints to datastore #12423

Merged
merged 10 commits into from
Apr 12, 2022
1 change: 1 addition & 0 deletions changelog.d/12423.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type hints to datastore.
4 changes: 2 additions & 2 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ async def _send_renewal_emails(self) -> None:
expiring_users = await self.store.get_users_expiring_soon()

if expiring_users:
for user in expiring_users:
for user_id, expiration_ts_ms in expiring_users:
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
user_id=user_id, expiration_ts=expiration_ts_ms
)

async def send_renewal_email_to_user(self, user_id: str) -> None:
Expand Down
28 changes: 18 additions & 10 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple

from synapse.appservice import (
ApplicationService,
Expand All @@ -26,7 +26,11 @@
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
Expand Down Expand Up @@ -92,7 +96,7 @@ def get_max_as_txn_id(txn: Cursor) -> int:

super().__init__(database, db_conn, hs)

def get_app_services(self):
def get_app_services(self) -> List[ApplicationService]:
return self.services_cache

def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
Expand Down Expand Up @@ -256,7 +260,7 @@ async def create_appservice_txn(
A new transaction.
"""

def _create_appservice_txn(txn):
def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)

# Insert new txn into txn table
Expand Down Expand Up @@ -291,7 +295,7 @@ async def complete_appservice_txn(
service: The application service which was sent this transaction.
"""

def _complete_appservice_txn(txn):
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
Expand Down Expand Up @@ -322,7 +326,9 @@ async def get_oldest_unsent_txn(
An AppServiceTransaction or None.
"""

def _get_oldest_unsent_txn(txn):
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
Expand Down Expand Up @@ -364,7 +370,7 @@ def _get_oldest_unsent_txn(txn):
)

async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
Expand All @@ -378,7 +384,9 @@ async def get_new_events_for_appservice(
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""

def get_new_events_for_appservice_txn(txn):
def get_new_events_for_appservice_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
Expand Down Expand Up @@ -416,7 +424,7 @@ async def get_type_stream_id_for_appservice(
% (type,)
)

def get_type_stream_id_for_appservice_txn(txn):
def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type
txn.execute(
# We do NOT want to escape `stream_id_type`.
Expand Down Expand Up @@ -444,7 +452,7 @@ async def set_appservice_stream_type_pos(
% (stream_type,)
)

def set_appservice_stream_type_pos_txn(txn):
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
Expand Down
Loading