Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert a synapse.events to async/await #7949

Merged
merged 7 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion changelog.d/7948.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Convert push to async/await.
Convert various parts of the codebase to async/await.
1 change: 1 addition & 0 deletions changelog.d/7949.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion changelog.d/7951.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Convert groups and visibility code to async / await.
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, hs):

@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
Expand Down
19 changes: 8 additions & 11 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import attr
from nacl.signing import SigningKey

from twisted.internet import defer

from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
Expand Down Expand Up @@ -95,31 +93,30 @@ def state_key(self):
def is_state(self):
return self._state_key is not None

@defer.inlineCallbacks
def build(self, prev_event_ids):
async def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event

Args:
prev_event_ids (list[str]): The event IDs to use as the prev events

Returns:
Deferred[FrozenEvent]
FrozenEvent
"""

state_ids = yield defer.ensureDeferred(
self._state.get_current_state_ids(self.room_id, prev_event_ids)
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
auth_ids = await self._auth.compute_auth_events(self, state_ids)

format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids

old_depth = yield self._store.get_max_depth_of(prev_event_ids)
old_depth = await self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1

# we cap depth of generated events, to ensure that they are not
Expand Down
46 changes: 22 additions & 24 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import attr
from frozendict import frozendict

from twisted.internet import defer

from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap

if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore


@attr.s(slots=True)
class EventContext:
Expand Down Expand Up @@ -129,8 +131,7 @@ def with_state(
delta_ids=delta_ids,
)

@defer.inlineCallbacks
def serialize(self, event, store):
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`

Expand All @@ -146,7 +147,7 @@ def serialize(self, event, store):
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_ids = yield self.get_prev_state_ids()
prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
Expand Down Expand Up @@ -214,8 +215,7 @@ def state_group(self) -> Optional[int]:

return self._state_group

@defer.inlineCallbacks
def get_current_state_ids(self):
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``

Expand All @@ -224,32 +224,31 @@ def get_current_state_ids(self):
``rejected`` is set.

Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
Returns None if state_group is None, which happens when the associated
event is an outlier.

Maps a (type, state_key) to the event ID of the state event matching
this tuple.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

yield self._ensure_fetched()
await self._ensure_fetched()
return self._current_state_ids

@defer.inlineCallbacks
def get_prev_state_ids(self):
async def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.

For a non-state event, this will be the same as get_current_state_ids().

Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
yield self._ensure_fetched()
await self._ensure_fetched()
return self._prev_state_ids

def get_cached_current_state_ids(self):
Expand All @@ -269,8 +268,8 @@ def get_cached_current_state_ids(self):

return self._current_state_ids

def _ensure_fetched(self):
return defer.succeed(None)
async def _ensure_fetched(self):
return None


@attr.s(slots=True)
Expand Down Expand Up @@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)

def _ensure_fetched(self):
async def _ensure_fetched(self):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)

return make_deferred_yieldable(self._fetching_state_deferred)
return await make_deferred_yieldable(self._fetching_state_deferred)

@defer.inlineCallbacks
def _fill_out_state(self):
async def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return

self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._event_state_key is not None:
Expand Down
55 changes: 30 additions & 25 deletions synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester


class ThirdPartyEventRules(object):
Expand All @@ -39,76 +41,79 @@ def __init__(self, hs):
config=config, http_client=hs.get_simple_http_client()
)

@defer.inlineCallbacks
def check_event_allowed(self, event, context):
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> bool:
"""Check if a provided event should be allowed in the given context.

Args:
event (synapse.events.EventBase): The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event.
event: The event to be checked.
context: The context of the event.

Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not.
True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
return True

prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()

# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
state_events[key] = await self.store.get_event(event_id, allow_none=True)

ret = yield self.third_party_rules.check_event_allowed(event, state_events)
ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret

@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool:
"""Intercept requests to create room to allow, deny or update the
request config.

Args:
requester (Requester)
config (dict): The creation config from the client.
is_requester_admin (bool): If the requester is an admin
requester
config: The creation config from the client.
is_requester_admin: If the requester is an admin

Returns:
defer.Deferred[bool]: Whether room creation is allowed or denied.
Whether room creation is allowed or denied.
"""

if self.third_party_rules is None:
return True

ret = yield self.third_party_rules.on_create_room(
ret = await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
return ret

@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, room_id):
async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
) -> bool:
"""Check if a provided 3PID can be invited in the given room.

Args:
medium (str): The 3PID's medium.
address (str): The 3PID's address.
room_id (str): The room we want to invite the threepid to.
medium: The 3PID's medium.
address: The 3PID's address.
room_id: The room we want to invite the threepid to.

Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not.
True if the 3PID can be invited, False if not.
"""

if self.third_party_rules is None:
return True

state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())

state_events = {}
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]

ret = yield self.third_party_rules.check_threepid_can_be_invited(
ret = await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
return ret
15 changes: 7 additions & 8 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from frozendict import frozendict

from twisted.internet import defer

from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
Expand Down Expand Up @@ -337,8 +335,9 @@ def __init__(self, hs):
hs.config.experimental_msc1849_support_enabled
)

@defer.inlineCallbacks
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
async def serialize_event(
self, event, time_now, bundle_aggregations=True, **kwargs
):
"""Serializes a single event.

Args:
Expand All @@ -348,7 +347,7 @@ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
**kwargs: Arguments to pass to `serialize_event`

Returns:
Deferred[dict]: The serialized event
dict: The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
Expand All @@ -363,8 +362,8 @@ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
if not event.internal_metadata.is_redacted() and (
self.experimental_msc1849_support_enabled and bundle_aggregations
):
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
annotations = await self.store.get_aggregation_groups_for_event(event_id)
references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
)

Expand All @@ -378,7 +377,7 @@ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):

edit = None
if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit(event_id)
edit = await self.store.get_applicable_edit(event_id)

if edit:
# If there is an edit replace the content, preserving existing
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2470,7 +2470,7 @@ async def _update_context_for_auth_events(
}

current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids)
current_state_ids = dict(current_state_ids) # type: ignore

current_state_ids.update(state_updates)

Expand Down
4 changes: 3 additions & 1 deletion synapse/replication/http/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def _serialize_payload(store, event_and_contexts, backfilled):
"""
event_payloads = []
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
serialized_context = yield defer.ensureDeferred(
context.serialize(event, store)
)

event_payloads.append(
{
Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/http/send_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _serialize_payload(
extra_users (list(UserID)): Any extra users to notify about event
"""

serialized_context = yield context.serialize(event, store)
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))

payload = {
"event": event.get_pdu_json(),
Expand Down
Loading