Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add missing type hints to event fetching. (#11121)
Browse files Browse the repository at this point in the history
Updates the event rows returned from the database to be
attrs classes instead of dictionaries.
  • Loading branch information
clokep authored Oct 19, 2021
1 parent 5e0e683 commit 0dd0c40
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 61 deletions.
1 change: 1 addition & 0 deletions changelog.d/11121.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints for event fetching.
142 changes: 81 additions & 61 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
Expand Down Expand Up @@ -86,6 +87,47 @@ class _EventCacheEntry:
redacted_event: Optional[EventBase]


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRow:
    """
    An event, as pulled from the database.

    Properties:
        event_id: The event ID of the event.
        stream_ordering: stream ordering for this event
        json: json-encoded event structure
        internal_metadata: json-encoded internal metadata dict
        format_version: The format of the event. Hopefully one of
            EventFormatVersions. 'None' means the event predates
            EventFormatVersions (so the event is format V1).
        room_version_id: The version of the room which contains the event.
            Hopefully one of RoomVersions.

            Due to historical reasons, there may be a few events in the
            database which do not have an associated room; in this case None
            will be returned here.
        rejected_reason: if the event was rejected, the reason why.
        redactions: a list of event-ids which (claim to) redact this event.
        outlier: True if this event is an outlier.
    """

    event_id: str
    stream_ordering: int
    json: str
    internal_metadata: str
    format_version: Optional[int]
    # Room version identifiers are strings, not integers.
    room_version_id: Optional[str]
    rejected_reason: Optional[str]
    # `frozen=True` only prevents rebinding the attribute: the list itself is
    # created empty and appended to in place while the rows are assembled.
    redactions: List[str]
    outlier: bool


class EventRedactBehaviour(Names):
"""
What to do when retrieving a redacted event from the database.
Expand Down Expand Up @@ -686,7 +728,7 @@ async def get_stripped_room_state_from_event_context(
for e in state_to_include.values()
]

def _do_fetch(self, conn):
def _do_fetch(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
Expand All @@ -713,13 +755,15 @@ def _do_fetch(self, conn):

self._fetch_event_list(conn, event_list)

def _fetch_event_list(self, conn, event_list):
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
Args:
conn (twisted.enterprise.adbapi.Connection): database connection
conn: database connection
event_list (list[Tuple[list[str], Deferred]]):
event_list:
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
Expand Down Expand Up @@ -788,7 +832,7 @@ async def _get_events_from_db(
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
redaction_ids.update(row["redactions"])
redaction_ids.update(row.redactions)

events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
Expand All @@ -799,32 +843,32 @@ async def _get_events_from_db(
for event_id, row in fetched_events.items():
if not row:
continue
assert row["event_id"] == event_id
assert row.event_id == event_id

rejected_reason = row["rejected_reason"]
rejected_reason = row.rejected_reason

# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
d = db_to_json(row["json"])
d = db_to_json(row.json)
except ValueError:
logger.error("Unable to parse json from event: %s", event_id)
continue
try:
internal_metadata = db_to_json(row["internal_metadata"])
internal_metadata = db_to_json(row.internal_metadata)
except ValueError:
logger.error(
"Unable to parse internal_metadata from event: %s", event_id
)
continue

format_version = row["format_version"]
format_version = row.format_version
if format_version is None:
# This means that we stored the event before we had the concept
# of an event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1

room_version_id = row["room_version_id"]
room_version_id = row.room_version_id

if not room_version_id:
# this should only happen for out-of-band membership events which
Expand Down Expand Up @@ -889,16 +933,16 @@ async def _get_events_from_db(
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
original_ev.internal_metadata.outlier = row["outlier"]
original_ev.internal_metadata.stream_ordering = row.stream_ordering
original_ev.internal_metadata.outlier = row.outlier

event_map[event_id] = original_ev

# finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id]["redactions"]
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
Expand All @@ -912,17 +956,17 @@ async def _get_events_from_db(

return result_map

async def _enqueue_events(self, events):
async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Args:
events (Iterable[str]): events to be fetched.
events: events to be fetched.
Returns:
Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
A map from event id to row data from the database. May contain events
that weren't requested.
"""

events_d = defer.Deferred()
Expand All @@ -949,43 +993,19 @@ async def _enqueue_events(self, events):

return row_map

def _fetch_event_rows(self, txn, event_ids):
def _fetch_event_rows(
self, txn: LoggingTransaction, event_ids: Iterable[str]
) -> Dict[str, _EventRow]:
"""Fetch event rows from the database
Events which are not found are omitted from the result.
The returned per-event dicts contain the following keys:
* event_id (str)
* stream_ordering (int): stream ordering for this event
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
* format_version (int|None): The format of the event. Hopefully one
of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1).
* room_version_id (str|None): The version of the room which contains the event.
Hopefully one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
* rejected_reason (str|None): if the event was rejected, the reason
why.
* redactions (List[str]): a list of event-ids which (claim to) redact
this event.
Args:
txn (twisted.enterprise.adbapi.Connection):
event_ids (Iterable[str]): event IDs to fetch
txn: The database transaction.
event_ids: event IDs to fetch
Returns:
Dict[str, Dict]: a map from event id to event info.
A map from event id to event info.
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
Expand Down Expand Up @@ -1013,17 +1033,17 @@ def _fetch_event_rows(self, txn, event_ids):

for row in txn:
event_id = row[0]
event_dict[event_id] = {
"event_id": event_id,
"stream_ordering": row[1],
"internal_metadata": row[2],
"json": row[3],
"format_version": row[4],
"room_version_id": row[5],
"rejected_reason": row[6],
"redactions": [],
"outlier": row[7],
}
event_dict[event_id] = _EventRow(
event_id=event_id,
stream_ordering=row[1],
internal_metadata=row[2],
json=row[3],
format_version=row[4],
room_version_id=row[5],
rejected_reason=row[6],
redactions=[],
outlier=row[7],
)

# check for redactions
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
Expand All @@ -1035,7 +1055,7 @@ def _fetch_event_rows(self, txn, event_ids):
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
if d:
d["redactions"].append(redacter)
d.redactions.append(redacter)

return event_dict

Expand Down

0 comments on commit 0dd0c40

Please sign in to comment.