Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Remove and switch away from get_create_event_for_room_txn
Browse files Browse the repository at this point in the history
  • Loading branch information
MadLittleMods committed Aug 20, 2021
1 parent fffce99 commit aafa069
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 118 deletions.
18 changes: 14 additions & 4 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,8 +1773,13 @@ def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
# Skip processing an insertion event if the room version doesn't
# support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
create_event = self.store.get_create_event_for_room_txn(txn, event.room_id)
room_creator = create_event.content.get("creator", None)
room_creator = self.db_pool.simple_select_one_onecol_txn(
txn,
table="rooms",
keyvalues={"room_id": event.room_id},
retcol="creator",
allow_none=True,
)
if not room_version.msc2716_historical or event.sender != room_creator:
return

Expand Down Expand Up @@ -1826,8 +1831,13 @@ def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
# Skip processing a chunk event if the room version doesn't
# support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
create_event = self.store.get_create_event_for_room_txn(txn, event.room_id)
room_creator = create_event.content.get("creator", None)
room_creator = self.db_pool.simple_select_one_onecol_txn(
txn,
table="rooms",
keyvalues={"room_id": event.room_id},
retcol="creator",
allow_none=True,
)
if not room_version.msc2716_historical or event.sender != room_creator:
return

Expand Down
70 changes: 1 addition & 69 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
Expand Down Expand Up @@ -223,74 +223,6 @@ async def get_received_ts(self, event_id: str) -> Optional[int]:
desc="get_received_ts",
)

# Inform mypy that if allow_none is False (the default) then get_event

# always returns an EventBase.
@overload
def get_event_txn(
self,
event_id: str,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
) -> EventBase:
...

@overload
def get_event_txn(
self,
event_id: str,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
) -> Optional[EventBase]:
...

def get_event_txn(
self,
txn: LoggingTransaction,
event_id: str,
allow_rejected: bool = False,
allow_none: bool = False,
) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
txn: Transaction object
event_id: The event_id of the event to fetch
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected: If True, return rejected events. Otherwise,
behave as per allow_none.
allow_none: If True, return None if no event found, if
False throw a NotFoundError
check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
The event, or None if the event was not found and allow_none=True
Raises:
NotFoundError: if the event_id was not found and allow_none=False
"""
event_map = self._fetch_event_rows(txn, [event_id])
event_info = event_map[event_id]
if event_info is None and not allow_none:
raise NotFoundError("Could not find event %s" % (event_id,))

rejected_reason = event_info["rejected_reason"]
if not allow_rejected and rejected_reason:
return

d = db_to_json(event_info["json"])
internal_metadata = db_to_json(event_info["internal_metadata"])
room_version_id = event_info["room_version_id"]
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)

event = make_event_from_dict(
event_dict=d,
room_version=room_version,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)

return event

# Inform mypy that if allow_none is False (the default) then get_event
# always returns an EventBase.
@overload
Expand Down
57 changes: 12 additions & 45 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,34 +178,15 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase:
Raises:
NotFoundError if the room is unknown
"""
return await self.db_pool.runInteraction(
"get_create_event_for_room_txn",
self.get_create_event_for_room_txn,
room_id,
)

def get_create_event_for_room_txn(
self, txn: LoggingTransaction, room_id: str
) -> EventBase:
"""Get the create state event for a room.
Args:
txn: Transaction object
room_id: The room ID.
Returns:
The room creation event.
Raises:
NotFoundError if the room is unknown
"""

state_ids = self.get_current_state_ids_txn(txn, room_id)
state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))

# If we can't find the create event, assume we've hit a dead end
if not create_id:
raise NotFoundError("Unknown room %s" % (room_id,))

# Retrieve the room's create event and return
create_event = self.get_event_txn(txn, create_id)
create_event = await self.get_event(create_id)
return create_event

@cached(max_entries=100000, iterable=True)
Expand All @@ -219,35 +200,21 @@ async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
Returns:
The current state of the room.
"""
return await self.db_pool.runInteraction(
"get_current_state_ids_txn",
self.get_current_state_ids_txn,
room_id,
)

def get_current_state_ids_txn(
self, txn: LoggingTransaction, room_id: str
) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.

Args:
txn: Transaction object
room_id: The room to get the state IDs of.
def _get_current_state_ids_txn(txn):
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
""",
(room_id,),
)

Returns:
The current state of the room.
"""
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
""",
(room_id,),
return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)

return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
Expand Down

0 comments on commit aafa069

Please sign in to comment.