diff --git a/changelog.d/9129.misc b/changelog.d/9129.misc new file mode 100644 index 000000000000..7800be3e7ed6 --- /dev/null +++ b/changelog.d/9129.misc @@ -0,0 +1 @@ +Various improvements to the federation client. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 302b2f69bcdd..d330ae5dbc57 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,6 +18,7 @@ import itertools import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -26,7 +27,6 @@ List, Mapping, Optional, - Sequence, Tuple, TypeVar, Union, @@ -61,6 +61,9 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) @@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.pdu_destination_tried = {} + self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -116,33 +119,32 @@ def _clear_tried_cache(self): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query( + async def make_query( self, - destination, - query_type, - args, - retry_on_dns_fail=False, - ignore_backoff=False, - ): + destination: str, + query_type: str, + args: dict, + retry_on_dns_fail: bool = False, + ignore_backoff: bool = False, + ) -> JsonDict: """Sends a federation Query to a remote homeserver of the given type and arguments. Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the + destination: Domain name of the remote homeserver + query_type: Category of the query type; should match the handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details + args: Mapping of strings to strings containing the details of the query request. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - a Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels(query_type).inc() - return self.transport_layer.make_query( + return await self.transport_layer.make_query( destination, query_type, args, @@ -151,42 +153,52 @@ def make_query( ) @log_function - def query_client_keys(self, destination, content, timeout): + async def query_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> JsonDict: """Query device keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys(destination, content, timeout) + return await self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def query_user_devices(self, destination, user_id, timeout=30000): + async def query_user_devices( + self, destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: """Query the device keys for a list of user ids hosted on a remote server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices(destination, user_id, timeout) + return await self.transport_layer.query_user_devices( + destination, user_id, timeout + ) @log_function - def claim_client_keys(self, destination, content, timeout): + async def claim_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys(destination, content, timeout) + return await self.transport_layer.claim_client_keys( + destination, content, timeout + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] @@ -195,10 +207,10 @@ async def backfill( given destination server. Args: - dest (str): The remote homeserver to ask. - room_id (str): The room_id to backfill. - limit (int): The maximum number of events to return. - extremities (list): our current backwards extremities, to backfill from + dest: The remote homeserver to ask. + room_id: The room_id to backfill. + limit: The maximum number of events to return. + extremities: our current backwards extremities, to backfill from """ logger.debug("backfill extrem=%s", extremities) @@ -370,7 +382,7 @@ async def _check_sigs_and_hash_and_fetch( for events that have failed their checks Returns: - Deferred : A list of PDUs that have valid signatures and hashes. + A list of PDUs that have valid signatures and hashes. """ deferreds = self._check_sigs_and_hashes(room_version, pdus) @@ -418,7 +430,9 @@ async def handle_check_result(pdu: EventBase, deferred: Deferred): else: return [p for p in valid_pdus if p] - async def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth( + self, destination: str, room_id: str, event_id: str + ) -> List[EventBase]: res = await self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = await self.store.get_room_version(room_id) @@ -700,18 +714,16 @@ async def send_request(destination) -> Dict[str, Any]: return await self._try_destination_list("send_join", destinations, send_request) - async def _do_send_join(self, destination: str, pdu: EventBase): + async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_join_v2( + return await self.transport_layer.send_join_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -769,7 +781,7 @@ async def _do_send_invite( time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_invite_v2( + return await self.transport_layer.send_invite_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -779,7 +791,6 @@ async def _do_send_invite( "invite_room_state": pdu.unsigned.get("invite_room_state", []), }, ) - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -842,18 +853,16 @@ async def send_request(destination: str) -> None: "send_leave", destinations, send_request ) - async def _do_send_leave(self, destination, pdu): + async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_leave_v2( + return await self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -879,7 +888,7 @@ async def _do_send_leave(self, destination, pdu): # content. return resp[1] - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -887,7 +896,7 @@ def get_public_rooms( search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ): + ) -> JsonDict: """Get the list of public rooms from a remote homeserver Args: @@ -901,8 +910,7 @@ def get_public_rooms( party instance Returns: - Awaitable[Dict[str, Any]]: The response from the remote server, or None if - `remote_server` is the same as the local server_name + The response from the remote server. Raises: HttpResponseException: There was an exception returned from the remote server @@ -910,7 +918,7 @@ def get_public_rooms( requests over federation """ - return self.transport_layer.get_public_rooms( + return await self.transport_layer.get_public_rooms( remote_server, limit, since_token, @@ -923,7 +931,7 @@ async def get_missing_events( self, destination: str, room_id: str, - earliest_events_ids: Sequence[str], + earliest_events_ids: Iterable[str], latest_events: Iterable[EventBase], limit: int, min_depth: int, @@ -974,7 +982,9 @@ async def get_missing_events( return signed_events - async def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite( + self, destinations: Iterable[str], room_id: str, event_dict: JsonDict + ) -> None: for destination in destinations: if destination == self.server_name: continue @@ -983,7 +993,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict): await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) - return None + return except CodeMessageException: raise except Exception as e: @@ -995,7 +1005,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict): async def get_room_complexity( self, destination: str, room_id: str - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """ Fetch the complexity of a remote room from another server. @@ -1008,10 +1018,9 @@ async def get_room_complexity( could not fetch the complexity. """ try: - complexity = await self.transport_layer.get_room_complexity( + return await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us.