diff --git a/changelog.d/9402.bugfix b/changelog.d/9402.bugfix new file mode 100644 index 000000000000..7729225ba2d5 --- /dev/null +++ b/changelog.d/9402.bugfix @@ -0,0 +1 @@ +Fix a bug where a lot of unnecessary presence updates were sent when joining a room. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 97fc4d0a82b4..24ebc4b8031f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -474,7 +474,7 @@ async def send_presence(self, states: List[UserPresenceState]): self._processing_pending_presence = False def send_presence_to_destinations( - self, states: List[UserPresenceState], destinations: List[str] + self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. destinations (list[str]) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fb85b19770d1..b6a9ce4f389d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -849,6 +849,9 @@ async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ + # A map of destination to a set of user state that they should receive + presence_destinations = {} # type: Dict[str, Set[UserPresenceState]] + for delta in deltas: typ = delta["type"] state_key = delta["state_key"] @@ -858,6 +861,7 @@ async def _handle_state_delta(self, deltas): logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + # Drop any event that isn't a membership join if typ != EventTypes.Member: continue @@ -880,13 +884,38 @@ async def _handle_state_delta(self, deltas): # Ignore changes to join events. continue - await self._on_user_joined_room(room_id, state_key) + # Retrieve any user presence state updates that need to be sent as a result, + # and the destinations that need to receive it + destinations, user_presence_states = await self._on_user_joined_room( + room_id, state_key + ) + + # Insert the destinations and respective updates into our destinations dict + for destination in destinations: + presence_destinations.setdefault(destination, set()).update( + user_presence_states + ) + + # Send out user presence updates for each destination + for destination, user_state_set in presence_destinations.items(): + self.federation.send_presence_to_destinations( + destinations=[destination], states=user_state_set + ) - async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: + async def _on_user_joined_room( + self, room_id: str, user_id: str + ) -> Tuple[List[str], List[UserPresenceState]]: """Called when we detect a user joining the room via the current state - delta stream. - """ + delta stream. Returns the destinations that need to be updated and the + presence updates to send to them. + + Args: + room_id: The ID of the room that the user has joined. + user_id: The ID of the user that has joined the room. + Returns: + A tuple of destinations and presence updates to send to them. + """ if self.is_mine_id(user_id): # If this is a local user then we need to send their presence # out to hosts in the room (who don't already have it) @@ -894,15 +923,15 @@ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = await self.current_state_for_user(user_id) - hosts = await self.state.get_current_hosts_in_room(room_id) + remote_hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = {host for host in hosts if host != self.server_name} + filtered_remote_hosts = [ + host for host in remote_hosts if host != self.server_name + ] - self.federation.send_presence_to_destinations( - states=[state], destinations=hosts - ) + state = await self.current_state_for_user(user_id) + return filtered_remote_hosts, [state] else: # A remote user has joined the room, so we need to: # 1. Check if this is a new server in the room @@ -915,6 +944,8 @@ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: # TODO: Check that this is actually a new server joining the # room. + remote_host = get_domain_from_id(user_id) + users = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, users)) @@ -934,10 +965,7 @@ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: or state.status_msg is not None ] - if states: - self.federation.send_presence_to_destinations( - states=states, destinations=[get_domain_from_id(user_id)] - ) + return [remote_host], states def should_notify(old_state, new_state): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index be2ee26f07cf..996c6141982a 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -521,7 +521,7 @@ def test_remote_joins(self): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=["server2"], states=[expected_state] + destinations=["server2"], states={expected_state} ) # @@ -533,7 +533,7 @@ def test_remote_joins(self): self.federation_sender.send_presence.assert_not_called() self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=["server3"], states=[expected_state] + destinations=["server3"], states={expected_state} ) def test_remote_gets_presence_when_local_user_joins(self): @@ -584,8 +584,14 @@ def test_remote_gets_presence_when_local_user_joins(self): self.presence_handler.current_state_for_user("@test2:server") ) self.assertEqual(expected_state.state, PresenceState.ONLINE) - self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations={"server2", "server3"}, states=[expected_state] + self.assertEqual( + self.federation_sender.send_presence_to_destinations.call_count, 2 + ) + self.federation_sender.send_presence_to_destinations.assert_any_call( + destinations=["server3"], states={expected_state} + ) + self.federation_sender.send_presence_to_destinations.assert_any_call( + destinations=["server2"], states={expected_state} ) def _add_new_user(self, room_id, user_id):