Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Various improvements to the federation client #9129

Merged
merged 4 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9129.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Various improvements to the federation client.
125 changes: 67 additions & 58 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand All @@ -26,7 +27,6 @@
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
Expand Down Expand Up @@ -61,6 +61,9 @@
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
Expand All @@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):


class FederationClient(FederationBase):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.pdu_destination_tried = {}
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
Expand Down Expand Up @@ -116,33 +119,32 @@ def _clear_tried_cache(self):
self.pdu_destination_tried[event_id] = destination_dict

@log_function
def make_query(
async def make_query(
self,
destination,
query_type,
args,
retry_on_dns_fail=False,
ignore_backoff=False,
):
destination: str,
query_type: str,
args: dict,
retry_on_dns_fail: bool = False,
ignore_backoff: bool = False,
) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type
and arguments.

Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
destination: Domain name of the remote homeserver
query_type: Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
args: Mapping of strings to strings containing the details
of the query request.
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.

Returns:
a Awaitable which will eventually yield a JSON object from the
response
The JSON object from the response
"""
sent_queries_counter.labels(query_type).inc()

return self.transport_layer.make_query(
return await self.transport_layer.make_query(
destination,
query_type,
args,
Expand All @@ -151,42 +153,52 @@ def make_query(
)

@log_function
def query_client_keys(self, destination, content, timeout):
async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Query device keys for a device hosted on a remote server.

Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
destination: Domain name of the remote homeserver
content: The query content.

Returns:
an Awaitable which will eventually yield a JSON object from the
response
The JSON object from the response
"""
sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(destination, content, timeout)
return await self.transport_layer.query_client_keys(
destination, content, timeout
)

@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(destination, user_id, timeout)
return await self.transport_layer.query_user_devices(
destination, user_id, timeout
)

@log_function
def claim_client_keys(self, destination, content, timeout):
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.

Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
destination: Domain name of the remote homeserver
content: The query content.

Returns:
an Awaitable which will eventually yield a JSON object from the
response
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(destination, content, timeout)
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)

async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
Expand All @@ -195,10 +207,10 @@ async def backfill(
given destination server.

Args:
dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill.
limit (int): The maximum number of events to return.
extremities (list): our current backwards extremities, to backfill from
dest: The remote homeserver to ask.
room_id: The room_id to backfill.
limit: The maximum number of events to return.
extremities: our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)

Expand Down Expand Up @@ -370,7 +382,7 @@ async def _check_sigs_and_hash_and_fetch(
for events that have failed their checks

Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)

Expand Down Expand Up @@ -418,7 +430,9 @@ async def handle_check_result(pdu: EventBase, deferred: Deferred):
else:
return [p for p in valid_pdus if p]

async def get_event_auth(self, destination, room_id, event_id):
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)

room_version = await self.store.get_room_version(room_id)
Expand Down Expand Up @@ -700,18 +714,16 @@ async def send_request(destination) -> Dict[str, Any]:

return await self._try_destination_list("send_join", destinations, send_request)

async def _do_send_join(self, destination: str, pdu: EventBase):
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()

try:
content = await self.transport_layer.send_join_v2(
return await self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
Expand Down Expand Up @@ -769,7 +781,7 @@ async def _do_send_invite(
time_now = self._clock.time_msec()

try:
content = await self.transport_layer.send_invite_v2(
return await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
Expand All @@ -779,7 +791,6 @@ async def _do_send_invite(
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
Expand Down Expand Up @@ -842,18 +853,16 @@ async def send_request(destination: str) -> None:
"send_leave", destinations, send_request
)

async def _do_send_leave(self, destination, pdu):
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()

try:
content = await self.transport_layer.send_leave_v2(
return await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
Expand All @@ -879,15 +888,15 @@ async def _do_send_leave(self, destination, pdu):
# content.
return resp[1]

def get_public_rooms(
async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
):
) -> JsonDict:
"""Get the list of public rooms from a remote homeserver

Args:
Expand All @@ -901,16 +910,15 @@ def get_public_rooms(
party instance

Returns:
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
The response from the remote server.

Raises:
HttpResponseException: There was an exception returned from the remote server
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
requests over federation

"""
return self.transport_layer.get_public_rooms(
return await self.transport_layer.get_public_rooms(
remote_server,
limit,
since_token,
Expand All @@ -923,7 +931,7 @@ async def get_missing_events(
self,
destination: str,
room_id: str,
earliest_events_ids: Sequence[str],
earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
Expand Down Expand Up @@ -974,7 +982,9 @@ async def get_missing_events(

return signed_events

async def forward_third_party_invite(self, destinations, room_id, event_dict):
async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations:
if destination == self.server_name:
continue
Expand All @@ -983,7 +993,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict):
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None
return
except CodeMessageException:
raise
except Exception as e:
Expand All @@ -995,7 +1005,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict):

async def get_room_complexity(
self, destination: str, room_id: str
) -> Optional[dict]:
) -> Optional[JsonDict]:
"""
Fetch the complexity of a remote room from another server.

Expand All @@ -1008,10 +1018,9 @@ async def get_room_complexity(
could not fetch the complexity.
"""
try:
complexity = await self.transport_layer.get_room_complexity(
return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
Expand Down