Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert events worker database to async/await. #8071

Merged
merged 15 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8071.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
auth_events: the existing room state.

Raises:
AuthError if the checks fail
Expand Down
18 changes: 9 additions & 9 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ async def send_invite(self, target_host, event):
return pdu

async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
event = await self.store.get_event(event_id) # type: EventBase # type: ignore
auth = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
)
Expand Down Expand Up @@ -1778,8 +1778,8 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase
"""

event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event_id, check_room_id=room_id
) # type: EventBase # type: ignore

state_groups = await self.state_store.get_state_groups(room_id, [event_id])

Expand All @@ -1806,8 +1806,8 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event_id, check_room_id=room_id
) # type: EventBase # type: ignore

state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])

Expand Down Expand Up @@ -2155,9 +2155,9 @@ async def _check_for_soft_fail(
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]

current_auth_events = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
(e.type, e.state_key): e for e in auth_events_map.values()
}

try:
Expand All @@ -2174,8 +2174,8 @@ async def on_query_auth(
raise AuthError(403, "Host not in room.")

event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event_id, check_room_id=room_id
) # type: EventBase # type: ignore

# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
Expand Down
20 changes: 10 additions & 10 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from canonicaljson import encode_canonical_json, json

Expand Down Expand Up @@ -644,7 +644,7 @@ async def send_nonmember_event(
event: EventBase,
context: EventContext,
ratelimit: bool = True,
) -> int:
) -> Union[int, EventBase]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like a bug that we should fix rather than document?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I wasn't sure it made sense to tie up in this PR though. I'll probably split out. Do you have any opinion about what this should return in the "error" case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #8093.

"""
Persists and notifies local clients and federation of an event.

Expand Down Expand Up @@ -682,7 +682,7 @@ async def send_nonmember_event(

async def deduplicate_state_event(
self, event: EventBase, context: EventContext
) -> None:
) -> Optional[EventBase]:
"""
Checks whether event is in the latest resolved state in context.

Expand All @@ -692,25 +692,25 @@ async def deduplicate_state_event(
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
return None
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
return None

if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
return
return None

async def create_and_send_nonmember_event(
self,
requester: Requester,
event_dict: dict,
ratelimit: bool = True,
txn_id: Optional[str] = None,
) -> Tuple[EventBase, int]:
) -> Tuple[EventBase, Union[int, EventBase]]:
"""
Creates an event, then sends it.

Expand Down Expand Up @@ -957,7 +957,7 @@ async def persist_and_notify_client_event(
allow_none=True,
)

is_admin_redaction = (
is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)

Expand Down Expand Up @@ -1077,8 +1077,8 @@ def is_inviter_member_event(e):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}

room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
Expand Down
32 changes: 22 additions & 10 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,17 @@ async def _local_membership_update(
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(
prev_member_event_id
) # type: EventBase # type: ignore
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(
prev_member_event_id
) # type: EventBase # type: ignore
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id)

Expand Down Expand Up @@ -694,13 +698,17 @@ async def send_membership_event(
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(
prev_member_event_id
) # type: EventBase # type: ignore
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(
prev_member_event_id
) # type: EventBase # type: ignore
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)

Expand All @@ -714,9 +722,11 @@ async def _can_guest_join(
if not guest_access_id:
return False

guest_access = await self.store.get_event(guest_access_id)
guest_access = await self.store.get_event(
guest_access_id
) # type: EventBase # type: ignore

return (
return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
Expand Down Expand Up @@ -772,7 +782,7 @@ async def do_3pid_invite(
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
) -> Union[int, EventBase]:
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
Expand Down Expand Up @@ -806,7 +816,7 @@ async def do_3pid_invite(
if invitee:
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
) # type: Tuple[Any, Union[int, EventBase]] # type: ignore
else:
stream_id = await self._make_and_store_3pid_invite(
requester,
Expand All @@ -831,7 +841,7 @@ async def _make_and_store_3pid_invite(
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
) -> Union[int, EventBase]:
room_state = await self.state_handler.get_current_state(room_id)

inviter_display_name = ""
Expand Down Expand Up @@ -1066,7 +1076,9 @@ async def remote_reject_invite(

Implements RoomMemberHandler.remote_reject_invite
"""
invite_event = await self.store.get_event(invite_event_id)
invite_event = await self.store.get_event(
invite_event_id
) # type: EventBase # type: ignore
room_id = invite_event.room_id
target_user = invite_event.state_key

Expand Down
3 changes: 2 additions & 1 deletion synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from twisted.internet.protocol import ReconnectingClientFactory

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams import TypingStream
Expand Down Expand Up @@ -145,7 +146,7 @@ async def on_rdata(

event = await self.store.get_event(
row.data.event_id, allow_rejected=True
)
) # type: EventBase # type: ignore
if event.rejected_reason:
continue

Expand Down
2 changes: 1 addition & 1 deletion synapse/spam_checker_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield self._store.get_events(state_ids.values())
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def get_events(self, event_ids, allow_rejected=False):
allow_rejected (bool): If True return rejected events.

Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
"""

return self.store.get_events(
Expand Down
30 changes: 14 additions & 16 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
async def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.

Args:
Expand All @@ -40,9 +40,10 @@ def get_auth_chain(self, event_ids, include_given=False):
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
)
return await self.get_events_as_list(event_ids)

def get_auth_chain_ids(
self,
Expand Down Expand Up @@ -472,7 +473,7 @@ def get_forward_extremeties_for_room_txn(txn):
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)

def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`

Expand All @@ -482,17 +483,15 @@ def get_backfill_events(self, room_id, event_list, limit):
event_list (list)
limit (int)
"""
return (
self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
events = await self.get_events_as_list(event_ids)
return sorted(events, key=lambda e: -e.depth)

def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
Expand Down Expand Up @@ -553,8 +552,7 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi
latest_events,
limit,
)
events = await self.get_events_as_list(ids)
return events
return await self.get_events_as_list(ids)

def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

Expand Down
Loading