Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move get_bundled_aggregations to relations handler. #12237

Merged
merged 4 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12237.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor the relations endpoints to add a `RelationsHandler`.
2 changes: 1 addition & 1 deletion synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from . import EventBase

if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
from synapse.storage.databases.main.relations import BundledAggregations


# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
Expand Down
5 changes: 4 additions & 1 deletion synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()

self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
Expand Down Expand Up @@ -539,7 +540,9 @@ async def get_messages(
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()

aggregations = await self.store.get_bundled_aggregations(events, user_id)
aggregations = await self._relations_handler.get_bundled_aggregations(
events, user_id
)

time_now = self.clock.time_msec()

Expand Down
151 changes: 149 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast

import attr
from frozendict import frozendict

from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore


logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool


@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.

Some values require additional processing during serialization.
"""

annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None

def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)


class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
Expand Down Expand Up @@ -103,7 +138,7 @@ async def get_relations(
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self._main_store.get_bundled_aggregations(
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
Expand All @@ -115,3 +150,115 @@ async def get_relations(
return_value["original_event"] = original_event

return return_value

async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.

Note that this does not use a cache, but depends on cached methods.

Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.

Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""

# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None

event_id = event.event_id
room_id = event.room_id

# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()

annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)

references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))

# Store the bundled aggregations in the event metadata for later use.
return aggregations

async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.

Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.

Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}

# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}

# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result

# Fetch any edits (but not for redacted events).
edits = await self._main_store.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit

# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)

return results
5 changes: 3 additions & 2 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
Expand Down Expand Up @@ -1118,6 +1118,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._relations_handler = hs.get_relations_handler()

async def get_event_context(
self,
Expand Down Expand Up @@ -1190,7 +1191,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
event = filtered[0]

# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
Expand Down Expand Up @@ -354,7 +355,7 @@ async def _search(

aggregations = None
if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
Expand Down
9 changes: 6 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
Expand Down Expand Up @@ -269,6 +269,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
Expand Down Expand Up @@ -638,8 +639,10 @@ async def _load_filtered_recents(
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations(
recents, sync_config.user.to_string()
bundled_aggregations = (
await self._relations_handler.get_bundled_aggregations(
recents, sync_config.user.to_string()
)
)

return TimelineBatch(
Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth()

async def on_GET(
Expand All @@ -663,7 +664,7 @@ async def on_GET(

if event:
# Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations(
aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string()
)

Expand Down
Loading