Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Various improvements to the federation client. (#9129)
Browse files Browse the repository at this point in the history
* Type hints for `FederationClient`.
* Using `async` functions instead of returning `Awaitable` instances.
  • Loading branch information
clokep committed Jan 20, 2021
1 parent a5b9c87 commit 620ecf1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 58 deletions.
1 change: 1 addition & 0 deletions changelog.d/9129.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Various improvements to the federation client.
125 changes: 67 additions & 58 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand All @@ -26,7 +27,6 @@
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
Expand Down Expand Up @@ -61,6 +61,9 @@
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
Expand All @@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):


class FederationClient(FederationBase):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.pdu_destination_tried = {}
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
Expand Down Expand Up @@ -116,33 +119,32 @@ def _clear_tried_cache(self):
self.pdu_destination_tried[event_id] = destination_dict

@log_function
def make_query(
async def make_query(
self,
destination,
query_type,
args,
retry_on_dns_fail=False,
ignore_backoff=False,
):
destination: str,
query_type: str,
args: dict,
retry_on_dns_fail: bool = False,
ignore_backoff: bool = False,
) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
destination: Domain name of the remote homeserver
query_type: Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
args: Mapping of strings to strings containing the details
of the query request.
ignore_backoff (bool): true to ignore the historical backoff data
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
a Awaitable which will eventually yield a JSON object from the
response
The JSON object from the response
"""
sent_queries_counter.labels(query_type).inc()

return self.transport_layer.make_query(
return await self.transport_layer.make_query(
destination,
query_type,
args,
Expand All @@ -151,42 +153,52 @@ def make_query(
)

@log_function
def query_client_keys(self, destination, content, timeout):
async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Query device keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
destination: Domain name of the remote homeserver
content: The query content.
Returns:
an Awaitable which will eventually yield a JSON object from the
response
The JSON object from the response
"""
sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(destination, content, timeout)
return await self.transport_layer.query_client_keys(
destination, content, timeout
)

@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(destination, user_id, timeout)
return await self.transport_layer.query_user_devices(
destination, user_id, timeout
)

@log_function
def claim_client_keys(self, destination, content, timeout):
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
destination: Domain name of the remote homeserver
content: The query content.
Returns:
an Awaitable which will eventually yield a JSON object from the
response
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(destination, content, timeout)
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)

async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
Expand All @@ -195,10 +207,10 @@ async def backfill(
given destination server.
Args:
dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill.
limit (int): The maximum number of events to return.
extremities (list): our current backwards extremities, to backfill from
dest: The remote homeserver to ask.
room_id: The room_id to backfill.
limit: The maximum number of events to return.
extremities: our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)

Expand Down Expand Up @@ -370,7 +382,7 @@ async def _check_sigs_and_hash_and_fetch(
for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)

Expand Down Expand Up @@ -418,7 +430,9 @@ async def handle_check_result(pdu: EventBase, deferred: Deferred):
else:
return [p for p in valid_pdus if p]

async def get_event_auth(self, destination, room_id, event_id):
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)

room_version = await self.store.get_room_version(room_id)
Expand Down Expand Up @@ -700,18 +714,16 @@ async def send_request(destination) -> Dict[str, Any]:

return await self._try_destination_list("send_join", destinations, send_request)

async def _do_send_join(self, destination: str, pdu: EventBase):
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()

try:
content = await self.transport_layer.send_join_v2(
return await self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
Expand Down Expand Up @@ -769,7 +781,7 @@ async def _do_send_invite(
time_now = self._clock.time_msec()

try:
content = await self.transport_layer.send_invite_v2(
return await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
Expand All @@ -779,7 +791,6 @@ async def _do_send_invite(
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
Expand Down Expand Up @@ -842,18 +853,16 @@ async def send_request(destination: str) -> None:
"send_leave", destinations, send_request
)

async def _do_send_leave(self, destination, pdu):
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()

try:
content = await self.transport_layer.send_leave_v2(
return await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
Expand All @@ -879,15 +888,15 @@ async def _do_send_leave(self, destination, pdu):
# content.
return resp[1]

def get_public_rooms(
async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
):
) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
Args:
Expand All @@ -901,16 +910,15 @@ def get_public_rooms(
party instance
Returns:
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
The response from the remote server.
Raises:
HttpResponseException: There was an exception returned from the remote server
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
requests over federation
"""
return self.transport_layer.get_public_rooms(
return await self.transport_layer.get_public_rooms(
remote_server,
limit,
since_token,
Expand All @@ -923,7 +931,7 @@ async def get_missing_events(
self,
destination: str,
room_id: str,
earliest_events_ids: Sequence[str],
earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
Expand Down Expand Up @@ -974,7 +982,9 @@ async def get_missing_events(

return signed_events

async def forward_third_party_invite(self, destinations, room_id, event_dict):
async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations:
if destination == self.server_name:
continue
Expand All @@ -983,7 +993,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict):
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None
return
except CodeMessageException:
raise
except Exception as e:
Expand All @@ -995,7 +1005,7 @@ async def forward_third_party_invite(self, destinations, room_id, event_dict):

async def get_room_complexity(
self, destination: str, room_id: str
) -> Optional[dict]:
) -> Optional[JsonDict]:
"""
Fetch the complexity of a remote room from another server.
Expand All @@ -1008,10 +1018,9 @@ async def get_room_complexity(
could not fetch the complexity.
"""
try:
complexity = await self.transport_layer.get_room_complexity(
return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
Expand Down

0 comments on commit 620ecf1

Please sign in to comment.