Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Reduce state pulled from DB due to sending typing and receipts over federation #12964

Merged
merged 6 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12964.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.
6 changes: 5 additions & 1 deletion synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()

self._storage_controllers = hs.get_storage_controllers()

self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id

Expand Down Expand Up @@ -602,7 +604,9 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
room_id = receipt.room_id

# Work out which remote servers should be poked and poke them.
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
domains = [
d
for d in domains_set
Expand Down
7 changes: 5 additions & 2 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class FollowerTypingHandler:

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
Expand Down Expand Up @@ -131,15 +132,17 @@ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
return

try:
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()

now = self.clock.time_msec()
self.wheel_timer.insert(
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)

for domain in {get_domain_from_id(u) for u in users}:
hosts = await self._storage_controllers.state.get_current_hosts_in_room(
member.room_id
)
for domain in hosts:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
Expand Down
4 changes: 0 additions & 4 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ async def get_current_users_in_room(
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)

async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _invalidate_state_caches(
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
if members_changed:
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
Expand Down
8 changes: 8 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
List,
Mapping,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -482,3 +483,10 @@ async def get_current_state_event(
room_id, StateFilter.from_types((key,))
)
return state_map.get(key)

async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
    """Return the hosts currently in the given room, based on current state.

    The partial-state room tracker's `await_full_state` is awaited for the
    room first; only once that has completed is the lookup delegated to the
    main store's `get_current_hosts_in_room`.

    Args:
        room_id: The ID of the room to look up.

    Returns:
        The result of the main store's `get_current_hosts_in_room` for
        this room (annotated as a set of host names).
    """
    # Finish awaiting the tracker before touching the store, so the order of
    # operations matches a plain "await, then query" sequence.
    tracker = self._partial_state_room_tracker
    await tracker.await_full_state(room_id)

    main_store = self.stores.main
    hosts = await main_store.get_current_hosts_in_room(room_id)
    return hosts
37 changes: 37 additions & 0 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,43 @@ async def _check_host_room_membership(

return True

@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""

# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate the result from that.
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is not None:
return {get_domain_from_id(u) for u in users}

if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
Comment on lines +909 to +910
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is because the SQL doesn't work on SQLite (or it would need a different dialect?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, exactly

users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}

# For PostgreSQL we can use a regex to pull out the domains of the joined
# users directly from `current_state_events`.

def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This regex seems to match get_domain_from_id for the cases I can think of. 👍

FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}

return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)

async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
Expand Down
14 changes: 7 additions & 7 deletions tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@

class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=mock_state_handler,
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)

hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
return_value=make_awaitable({"test", "host2"})
)

return hs

@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (
Expand Down
6 changes: 4 additions & 2 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ async def check_host_in_room(room_id: str, server_name: str) -> bool:

hs.get_event_auth_handler().check_host_in_room = check_host_in_room

def get_joined_hosts_for_room(room_id: str):
async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}

self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
hs.get_storage_controllers().state.get_current_hosts_in_room = (
get_current_hosts_in_room
)

async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
Expand Down