Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Rename storage classes #12913

Merged
merged 13 commits into from
May 31, 2022
1 change: 1 addition & 0 deletions changelog.d/12913.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rename storage classes.
10 changes: 5 additions & 5 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from synapse.types import JsonDict, StateMap

if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter

Expand Down Expand Up @@ -84,7 +84,7 @@ class EventContext:
incomplete state.
"""

_storage: "Storage"
_storage: "StorageControllers"
rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
Expand All @@ -97,7 +97,7 @@ class EventContext:

@staticmethod
def with_state(
storage: "Storage",
storage: "StorageControllers",
state_group: Optional[int],
state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]],
Expand All @@ -117,7 +117,7 @@ def with_state(

@staticmethod
def for_outlier(
storage: "Storage",
storage: "StorageControllers",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
Expand Down Expand Up @@ -147,7 +147,7 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
}

@staticmethod
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a
EventContext.

Expand Down
1 change: 0 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.handler = hs.get_federation_handler()
self.storage = hs.get_storage()
self._spam_checker = hs.get_spam_checker()
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
Expand Down
12 changes: 8 additions & 4 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

async def get_whois(self, user: UserID) -> JsonDict:
connections = []
Expand Down Expand Up @@ -197,7 +197,9 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->

from_key = events[-1].internal_metadata.after

events = await filter_events_for_client(self.storage, user_id, events)
events = await filter_events_for_client(
self._storage_controllers, user_id, events
)

writer.write_events(room_id, events)

Expand Down Expand Up @@ -233,7 +235,9 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
state = await self.state_storage.get_state_for_event(event_id)
state = await self._state_storage_controller.get_state_for_event(
event_id
)
writer.write_state(room_id, event_id, state)

return writer.finished()
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.state_storage = hs.get_storage().state
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname

Expand Down Expand Up @@ -203,7 +203,7 @@ async def get_user_ids_changed(
continue

# mapping from event_id -> state_dict
prev_state_ids = await self.state_storage.get_state_ids_for_events(
prev_state_ids = await self._state_storage.get_state_ids_for_events(
event_ids
)

Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def get_stream(
class EventHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()

async def get_event(
self,
Expand Down Expand Up @@ -177,7 +177,7 @@ async def get_event(
is_peeking = user.to_string() not in users

filtered = await filter_events_for_client(
self.storage, user.to_string(), [event], is_peeking=is_peeking
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
)

if not filtered:
Expand Down
30 changes: 18 additions & 12 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def __init__(self, hs: "HomeServer"):
self.hs = hs

self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
Expand Down Expand Up @@ -324,7 +324,7 @@ async def _maybe_backfill_inner(
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self.storage,
self._storage_controllers,
self.server_name,
events_to_check,
redact=False,
Expand Down Expand Up @@ -660,7 +660,7 @@ async def do_knock(
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]

context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -849,7 +849,7 @@ async def on_invite_request(
)
)

context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -878,7 +878,7 @@ async def do_remotely_reject_invite(

await self.federation_client.send_leave(host_list, event)

context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -1027,7 +1027,7 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))

state_groups = await self.state_storage.get_state_groups_ids(
state_groups = await self._state_storage_controller.get_state_groups_ids(
room_id, [event_id]
)

Expand Down Expand Up @@ -1078,7 +1078,9 @@ async def on_backfill_request(
],
)

events = await filter_events_for_server(self.storage, origin, events)
events = await filter_events_for_server(
self._storage_controllers, origin, events
)

return events

Expand Down Expand Up @@ -1109,7 +1111,9 @@ async def get_persisted_pdu(
if not in_room:
raise AuthError(403, "Host not in room.")

events = await filter_events_for_server(self.storage, origin, [event])
events = await filter_events_for_server(
self._storage_controllers, origin, [event]
)
event = events[0]
return event
else:
Expand Down Expand Up @@ -1138,7 +1142,7 @@ async def on_get_missing_events(
)

missing_events = await filter_events_for_server(
self.storage, origin, missing_events
self._storage_controllers, origin, missing_events
)

return missing_events
Expand Down Expand Up @@ -1480,9 +1484,11 @@ async def _sync_partial_state_room(
# clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id)
assert (
self.storage.persistence is not None
self._storage_controllers.persistence is not None
), "TODO(faster_joins): support for workers"
await self.storage.persistence.update_current_state(room_id)
await self._storage_controllers.persistence.update_current_state(
room_id
)

logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)
Expand Down
27 changes: 17 additions & 10 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ class FederationEventHandler:

def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self._storage = hs.get_storage()
self._state_storage = self._storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler()
Expand Down Expand Up @@ -535,7 +535,9 @@ async def update_state_for_partial_state_event(
)
return
await self._store.update_state_for_partial_state_event(event, context)
self._state_storage.notify_event_un_partial_stated(event.event_id)
self._state_storage_controller.notify_event_un_partial_stated(
event.event_id
)

async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
Expand Down Expand Up @@ -835,7 +837,9 @@ async def _resolve_state_at_missing_prevs(

try:
# Get the state of the events we know about
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen
)

# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
Expand Down Expand Up @@ -1436,7 +1440,7 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True

context = EventContext.for_outlier(self._storage)
context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
Expand Down Expand Up @@ -1613,7 +1617,7 @@ async def _check_for_soft_fail(
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.

state_sets_d = await self._state_storage.get_state_groups_ids(
state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
Expand Down Expand Up @@ -1885,7 +1889,7 @@ async def _update_context_for_auth_events(

# create a new state group as a delta from the existing one.
prev_group = context.state_group
state_group = await self._state_storage.store_state_group(
state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
Expand All @@ -1894,7 +1898,7 @@ async def _update_context_for_auth_events(
)

return EventContext.with_state(
storage=self._storage,
storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates,
Expand Down Expand Up @@ -1984,11 +1988,14 @@ async def persist_events_and_notify(
)
return result["max_stream_id"]
else:
assert self._storage.persistence
assert self._storage_controllers.persistence

# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
events, max_stream_token = await self._storage.persistence.persist_events(
(
events,
max_stream_token,
) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)

Expand Down
17 changes: 10 additions & 7 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def __init__(self, hs: "HomeServer"):
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

async def snapshot_all_rooms(
self,
Expand Down Expand Up @@ -198,7 +198,8 @@ async def handle_room(event: RoomsForUser) -> None:
event.stream_ordering,
)
deferred_room_state = run_in_background(
self.state_storage.get_state_for_events, [event.event_id]
self._state_storage_controller.get_state_for_events,
[event.event_id],
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)
Expand All @@ -218,7 +219,7 @@ async def handle_room(event: RoomsForUser) -> None:
).addErrback(unwrapFirstError)

messages = await filter_events_for_client(
self.storage, user_id, messages
self._storage_controllers, user_id, messages
)

start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
Expand Down Expand Up @@ -355,7 +356,9 @@ async def _room_initial_sync_parted(
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
room_state = await self.state_storage.get_state_for_event(member_event_id)
room_state = await self._state_storage_controller.get_state_for_event(
member_event_id
)

limit = pagin_config.limit if pagin_config else None
if limit is None:
Expand All @@ -369,7 +372,7 @@ async def _room_initial_sync_parted(
)

messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
self._storage_controllers, user_id, messages, is_peeking=is_peeking
)

start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
Expand Down Expand Up @@ -474,7 +477,7 @@ async def get_receipts() -> List[JsonDict]:
)

messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
self._storage_controllers, user_id, messages, is_peeking=is_peeking
)

start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
Expand Down
Loading