Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Move get_bundled_aggregations to relations handler. (#12237)
Browse files Browse the repository at this point in the history
The get_bundled_aggregations code is fairly high-level and uses
a lot of store methods, we move it into the handler as that seems
like a better fit.
  • Loading branch information
clokep authored Mar 18, 2022
1 parent 80e0e1f commit 8fe930c
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 157 deletions.
1 change: 1 addition & 0 deletions changelog.d/12237.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor the relations endpoints to add a `RelationsHandler`.
2 changes: 1 addition & 1 deletion synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from . import EventBase

if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
from synapse.storage.databases.main.relations import BundledAggregations


# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
Expand Down
5 changes: 4 additions & 1 deletion synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()

self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
Expand Down Expand Up @@ -539,7 +540,9 @@ async def get_messages(
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()

aggregations = await self.store.get_bundled_aggregations(events, user_id)
aggregations = await self._relations_handler.get_bundled_aggregations(
events, user_id
)

time_now = self.clock.time_msec()

Expand Down
151 changes: 149 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast

import attr
from frozendict import frozendict

from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore


logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool


@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""

annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None

def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)


class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
Expand Down Expand Up @@ -103,7 +138,7 @@ async def get_relations(
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self._main_store.get_bundled_aggregations(
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
Expand All @@ -115,3 +150,115 @@ async def get_relations(
return_value["original_event"] = original_event

return return_value

async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""

# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None

event_id = event.event_id
room_id = event.room_id

# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()

annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)

references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))

# Store the bundled aggregations in the event metadata for later use.
return aggregations

async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}

# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}

# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result

# Fetch any edits (but not for redacted events).
edits = await self._main_store.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit

# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)

return results
5 changes: 3 additions & 2 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
Expand Down Expand Up @@ -1118,6 +1118,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._relations_handler = hs.get_relations_handler()

async def get_event_context(
self,
Expand Down Expand Up @@ -1190,7 +1191,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
event = filtered[0]

# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
Expand Down Expand Up @@ -354,7 +355,7 @@ async def _search(

aggregations = None
if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
Expand Down
9 changes: 6 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
Expand Down Expand Up @@ -269,6 +269,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
Expand Down Expand Up @@ -638,8 +639,10 @@ async def _load_filtered_recents(
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations(
recents, sync_config.user.to_string()
bundled_aggregations = (
await self._relations_handler.get_bundled_aggregations(
recents, sync_config.user.to_string()
)
)

return TimelineBatch(
Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth()

async def on_GET(
Expand All @@ -663,7 +664,7 @@ async def on_GET(

if event:
# Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string()
)

Expand Down
Loading

0 comments on commit 8fe930c

Please sign in to comment.