This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to `synapse/storage/databases/main/events_bg_updates.py` (#11654)
dklimpel authored Dec 30, 2021
1 parent 2c7f5e7 commit 07a3b5d
Showing 3 changed files with 44 additions and 30 deletions.
1 change: 1 addition & 0 deletions changelog.d/11654.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
@@ -28,7 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
-|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@@ -200,6 +199,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.event_push_actions]
disallow_untyped_defs = True

+[mypy-synapse.storage.databases.main.events_bg_updates]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True

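With events_bg_updates.py dropped from the global exclude list and given its own per-module section, mypy now enforces disallow_untyped_defs for it: every function must be fully annotated. A minimal sketch of what that setting does (a hypothetical module, not part of this commit):

# demo.py, checked with disallow_untyped_defs = True

def untyped(progress, batch_size):  # mypy: error, function is missing type annotations
    return batch_size

def typed(progress: dict, batch_size: int) -> int:  # accepted: fully annotated
    return batch_size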
69 changes: 40 additions & 29 deletions synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast

import attr

@@ -240,12 +240,14 @@ def __init__(

################################################################################

-async def _background_reindex_fields_sender(self, progress, batch_size):
+async def _background_reindex_fields_sender(
+self, progress: JsonDict, batch_size: int
+) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)

-def reindex_txn(txn):
+def reindex_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
@@ -307,12 +309,14 @@ def reindex_txn(txn):

return result

-async def _background_reindex_origin_server_ts(self, progress, batch_size):
+async def _background_reindex_origin_server_ts(
+self, progress: JsonDict, batch_size: int
+) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)

-def reindex_search_txn(txn):
+def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -381,7 +385,9 @@ def reindex_search_txn(txn):

return result

-async def _cleanup_extremities_bg_update(self, progress, batch_size):
+async def _cleanup_extremities_bg_update(
+self, progress: JsonDict, batch_size: int
+) -> int:
"""Background update to clean out extremities that should have been
deleted previously.
@@ -402,12 +408,12 @@ async def _cleanup_extremities_bg_update(self, progress, batch_size):
# have any descendants, but if they do then we should delete those
# extremities.

-def _cleanup_extremities_bg_update_txn(txn):
+def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
# The set of extremity event IDs that we're checking this round
original_set = set()

-# A dict[str, set[str]] of event ID to their prev events.
-graph = {}
+# A dict[str, Set[str]] of event ID to their prev events.
+graph: Dict[str, Set[str]] = {}

# The set of descendants of the original set that are not rejected
# nor soft-failed. Ancestors of these events should be removed
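The graph change is a typical strict-mode fix: mypy cannot infer element types for an empty literal, so a bare assignment fails once the module is checked. A small sketch of the fix:

from typing import Dict, Set

def build_graph() -> Dict[str, Set[str]]:
    # A bare `graph = {}` would make mypy report: Need type annotation for "graph"
    graph: Dict[str, Set[str]] = {}
    graph.setdefault("$event_a", set()).add("$prev_event_b")
    return graph

print(build_graph())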
@@ -536,7 +542,7 @@ def _cleanup_extremities_bg_update_txn(txn):
room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
-self.get_latest_event_ids_in_room.invalidate, (room_id,)
+self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
)

self.db_pool.simple_delete_many_txn(
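The `# type: ignore[attr-defined]` suppressions follow from how Synapse composes its stores: this class is written as a mixin, and attributes such as get_latest_event_ids_in_room are defined on a sibling class that only exists in the final combined store, which mypy cannot see when checking this module alone. A toy illustration of the same situation (class and method names invented for the sketch):

class CacheStore:
    def invalidate_latest_events(self, room_id: str) -> None:
        print("invalidated cache for", room_id)

class BackgroundUpdateStore:
    def after_delete(self, room_id: str) -> None:
        # mypy checks this class in isolation, so the attribute from the
        # sibling class is unknown here and needs a narrow suppression:
        self.invalidate_latest_events(room_id)  # type: ignore[attr-defined]

class DataStore(BackgroundUpdateStore, CacheStore):
    """The combined store; at runtime the attribute resolves fine."""

DataStore().after_delete("!room:example.org")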
@@ -558,7 +564,7 @@ def _cleanup_extremities_bg_update_txn(txn):
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)

-def _drop_table_txn(txn):
+def _drop_table_txn(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE _extremities_to_check")

await self.db_pool.runInteraction(
@@ -567,11 +573,11 @@ def _drop_table_txn(txn):

return num_handled

-async def _redactions_received_ts(self, progress, batch_size):
+async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
"""Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")

-def _redactions_received_ts_txn(txn):
+def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
@@ -622,10 +628,12 @@ def _redactions_received_ts_txn(txn):

return count

-async def _event_fix_redactions_bytes(self, progress, batch_size):
+async def _event_fix_redactions_bytes(
+self, progress: JsonDict, batch_size: int
+) -> int:
"""Undoes hex encoded censored redacted event JSON."""

-def _event_fix_redactions_bytes_txn(txn):
+def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
# This update is quite fast due to new index.
txn.execute(
"""
@@ -650,11 +658,11 @@ def _event_fix_redactions_bytes_txn(txn):

return 1

-async def _event_store_labels(self, progress, batch_size):
+async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")

-def _event_store_labels_txn(txn):
+def _event_store_labels_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT event_id, json FROM event_json
@@ -754,7 +762,10 @@ def get_rejected_events(
),
)

-return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
+return cast(
+List[Tuple[str, str, JsonDict, bool, bool]],
+[(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+)

results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -912,7 +923,7 @@ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:

def _calculate_chain_cover_txn(
self,
-txn: Cursor,
+txn: LoggingTransaction,
last_room_id: str,
last_depth: int,
last_stream: int,
@@ -1023,10 +1034,10 @@ def _calculate_chain_cover_txn(
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
-self.event_chain_id_gen,
+self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
-event_to_auth_chain,
+cast(Dict[str, Sequence[str]], event_to_auth_chain),
)

return _CalculateChainCover(
@@ -1046,7 +1057,7 @@ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
"""
current_event_id = progress.get("current_event_id", "")

-def purged_chain_cover_txn(txn) -> int:
+def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
# The event ID from events will be null if the chain ID / sequence
# number points to a purged event.
sql = """
@@ -1181,14 +1192,14 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
-self._invalidate_cache_and_stream(
-txn, self.get_relations_for_event, cache_tuple
+self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
-self._invalidate_cache_and_stream(
-txn, self.get_aggregation_groups_for_event, cache_tuple
+self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
)
-self._invalidate_cache_and_stream(
-txn, self.get_thread_summary, cache_tuple
+self._invalidate_cache_and_stream( # type: ignore[attr-defined]
+txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)

if results:
@@ -1220,7 +1231,7 @@ async def _background_populate_stream_ordering2(
"""
batch_size = max(batch_size, 1)

-def process(txn: Cursor) -> int:
+def process(txn: LoggingTransaction) -> int:
last_stream = progress.get("last_stream", -(1 << 31))
txn.execute(
"""
